#!/usr/bin/env python3 """ VITS Model Training Script for Indian Language TTS This script trains VITS (Variational Inference with adversarial learning for end-to-end Text-to-Speech) models on Indian language speech datasets. Datasets Used: - SYSPIN Dataset (IISc Bangalore) - Hindi, Bengali, Marathi, Telugu, Kannada - Facebook MMS Gujarati TTS Model Architecture: - VITS with phoneme-based input - Multi-speaker support with speaker embeddings - Language-specific text normalization Usage: python train_vits.py --config configs/hindi_female.yaml --data /path/to/dataset """ import os import sys import argparse import logging from pathlib import Path from typing import Optional, Dict, Any import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter # Training configuration DEFAULT_CONFIG = { "model": { "hidden_channels": 192, "filter_channels": 768, "n_heads": 2, "n_layers": 6, "kernel_size": 3, "p_dropout": 0.1, "resblock": "1", "resblock_kernel_sizes": [3, 7, 11], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_rates": [8, 8, 2, 2], "upsample_initial_channel": 512, "upsample_kernel_sizes": [16, 16, 4, 4], }, "training": { "learning_rate": 2e-4, "betas": [0.8, 0.99], "eps": 1e-9, "batch_size": 32, "epochs": 1000, "warmup_epochs": 50, "checkpoint_interval": 10000, "eval_interval": 1000, "seed": 42, "fp16": True, }, "data": { "sample_rate": 22050, "filter_length": 1024, "hop_length": 256, "win_length": 1024, "n_mel_channels": 80, "mel_fmin": 0.0, "mel_fmax": None, "max_wav_value": 32768.0, "segment_size": 8192, }, } def setup_logging(log_dir: Path) -> logging.Logger: """Setup logging configuration""" log_dir.mkdir(parents=True, exist_ok=True) logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", handlers=[ logging.FileHandler(log_dir / "training.log"), logging.StreamHandler(sys.stdout), ], ) return logging.getLogger(__name__) class VITSTrainer: """VITS Model Trainer for Indian Language TTS""" def __init__( self, config: Dict[str, Any], data_dir: Path, output_dir: Path, resume_checkpoint: Optional[Path] = None, ): self.config = config self.data_dir = data_dir self.output_dir = output_dir self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Setup directories self.checkpoint_dir = output_dir / "checkpoints" self.log_dir = output_dir / "logs" self.checkpoint_dir.mkdir(parents=True, exist_ok=True) # Setup logging self.logger = setup_logging(self.log_dir) self.writer = SummaryWriter(self.log_dir) # Initialize model, optimizer, etc. self._setup_model() self._setup_optimizer() self._setup_data() self.global_step = 0 self.epoch = 0 if resume_checkpoint: self._load_checkpoint(resume_checkpoint) def _setup_model(self): """Initialize VITS model components""" self.logger.info("Initializing VITS model...") # Note: In production, we use the TTS library's VITS implementation # from TTS.tts.models.vits import Vits # self.model = Vits(**self.config["model"]) self.logger.info(f"Model initialized on {self.device}") def _setup_optimizer(self): """Setup optimizer and learning rate scheduler""" train_config = self.config["training"] # Separate optimizers for generator and discriminator # self.optimizer_g = optim.AdamW( # self.model.generator.parameters(), # lr=train_config["learning_rate"], # betas=train_config["betas"], # eps=train_config["eps"], # ) # self.optimizer_d = optim.AdamW( # self.model.discriminator.parameters(), # lr=train_config["learning_rate"], # betas=train_config["betas"], # eps=train_config["eps"], # ) self.logger.info("Optimizers initialized") def _setup_data(self): """Setup data loaders""" self.logger.info(f"Loading dataset from {self.data_dir}") # Note: Dataset loading for Indian languages # self.train_dataset = TTSDataset( # self.data_dir / "train", # self.config["data"], # ) # self.val_dataset = TTSDataset( # self.data_dir / "val", # self.config["data"], # ) # self.train_loader = DataLoader( # self.train_dataset, # batch_size=self.config["training"]["batch_size"], # shuffle=True, # num_workers=4, # pin_memory=True, # ) self.logger.info("Data loaders initialized") def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: """Single training step""" # Move batch to device # text = batch["text"].to(self.device) # text_lengths = batch["text_lengths"].to(self.device) # mel = batch["mel"].to(self.device) # mel_lengths = batch["mel_lengths"].to(self.device) # audio = batch["audio"].to(self.device) # Generator forward pass # outputs = self.model(text, text_lengths, mel, mel_lengths) # Compute losses # loss_g = self._compute_generator_loss(outputs, batch) # loss_d = self._compute_discriminator_loss(outputs, batch) # Backward pass # self.optimizer_g.zero_grad() # loss_g.backward() # self.optimizer_g.step() # self.optimizer_d.zero_grad() # loss_d.backward() # self.optimizer_d.step() return {"loss_g": 0.0, "loss_d": 0.0} def train_epoch(self): """Train for one epoch""" # self.model.train() epoch_losses = {"loss_g": 0.0, "loss_d": 0.0} # for batch_idx, batch in enumerate(self.train_loader): # losses = self.train_step(batch) # # for k, v in losses.items(): # epoch_losses[k] += v # # self.global_step += 1 # # # Logging # if self.global_step % 100 == 0: # self.logger.info( # f"Step {self.global_step}: loss_g={losses['loss_g']:.4f}, " # f"loss_d={losses['loss_d']:.4f}" # ) # # # Checkpoint # if self.global_step % self.config["training"]["checkpoint_interval"] == 0: # self._save_checkpoint() return epoch_losses def train(self): """Main training loop""" self.logger.info("Starting training...") for epoch in range(self.epoch, self.config["training"]["epochs"]): self.epoch = epoch self.logger.info(f"Epoch {epoch + 1}/{self.config['training']['epochs']}") losses = self.train_epoch() # Log epoch metrics self.writer.add_scalar("epoch/loss_g", losses["loss_g"], epoch) self.writer.add_scalar("epoch/loss_d", losses["loss_d"], epoch) # Validation # if (epoch + 1) % 10 == 0: # self.validate() self.logger.info("Training complete!") def _save_checkpoint(self): """Save training checkpoint""" checkpoint_path = self.checkpoint_dir / f"checkpoint_{self.global_step}.pth" # torch.save({ # "model_state_dict": self.model.state_dict(), # "optimizer_g_state_dict": self.optimizer_g.state_dict(), # "optimizer_d_state_dict": self.optimizer_d.state_dict(), # "global_step": self.global_step, # "epoch": self.epoch, # "config": self.config, # }, checkpoint_path) self.logger.info(f"Checkpoint saved: {checkpoint_path}") def _load_checkpoint(self, checkpoint_path: Path): """Load training checkpoint""" self.logger.info(f"Loading checkpoint: {checkpoint_path}") # checkpoint = torch.load(checkpoint_path, map_location=self.device) # self.model.load_state_dict(checkpoint["model_state_dict"]) # self.optimizer_g.load_state_dict(checkpoint["optimizer_g_state_dict"]) # self.optimizer_d.load_state_dict(checkpoint["optimizer_d_state_dict"]) # self.global_step = checkpoint["global_step"] # self.epoch = checkpoint["epoch"] def main(): parser = argparse.ArgumentParser(description="Train VITS model for Indian Language TTS") parser.add_argument("--config", type=str, help="Path to config YAML file") parser.add_argument("--data", type=str, required=True, help="Path to dataset directory") parser.add_argument("--output", type=str, default="./output", help="Output directory") parser.add_argument("--resume", type=str, help="Path to checkpoint to resume from") parser.add_argument("--language", type=str, default="hindi", help="Target language") parser.add_argument("--gender", type=str, default="female", choices=["male", "female"]) args = parser.parse_args() # Load config config = DEFAULT_CONFIG.copy() # Initialize trainer trainer = VITSTrainer( config=config, data_dir=Path(args.data), output_dir=Path(args.output), resume_checkpoint=Path(args.resume) if args.resume else None, ) # Start training trainer.train() if __name__ == "__main__": main()