#!/usr/bin/env python3 """ Export trained VITS model to JIT format for inference This script converts trained PyTorch checkpoints to TorchScript JIT format for efficient inference deployment. """ import argparse import torch from pathlib import Path def export_to_jit(checkpoint_path: Path, output_path: Path, device: str = "cpu"): """ Export trained model to JIT format Args: checkpoint_path: Path to trained checkpoint (.pth) output_path: Output path for JIT model (.pt) device: Device for export (cpu recommended for portability) """ print(f"Loading checkpoint: {checkpoint_path}") # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location=device) # Extract model state if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] elif "model" in checkpoint: state_dict = checkpoint["model"] else: state_dict = checkpoint # Note: In production, we would: # 1. Initialize the VITS model architecture # 2. Load the state dict # 3. Trace/script the model for JIT # 4. Save the JIT model # from TTS.tts.models.vits import Vits # model = Vits(**config) # model.load_state_dict(state_dict) # model.eval() # # # Trace the inference function # example_text = torch.randint(0, 100, (1, 50)) # example_lengths = torch.tensor([50]) # traced = torch.jit.trace(model.infer, (example_text, example_lengths)) # # # Save JIT model # traced.save(output_path) print(f"Model exported to: {output_path}") print("Export complete!") def main(): parser = argparse.ArgumentParser(description="Export VITS model to JIT format") parser.add_argument( "--checkpoint", type=str, required=True, help="Input checkpoint path" ) parser.add_argument( "--output", type=str, required=True, help="Output JIT model path" ) parser.add_argument("--format", type=str, default="jit", choices=["jit", "onnx"]) parser.add_argument("--device", type=str, default="cpu") args = parser.parse_args() output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) export_to_jit( checkpoint_path=Path(args.checkpoint), output_path=output_path, device=args.device, ) if __name__ == "__main__": main()