"""
TTS Engine for Multi-lingual Indian Language Speech Synthesis

This engine uses VITS (Variational Inference with adversarial learning for
end-to-end Text-to-Speech) models trained on various Indian language datasets.

Supported Languages:
    - Hindi, Bengali, Marathi, Telugu, Kannada
    - Gujarati (via Facebook MMS), Bhojpuri, Chhattisgarhi
    - Maithili, Magahi, English

Model Types:
    - JIT traced models (.pt) - Trained using train_vits.py
    - Coqui TTS checkpoints (.pth) - For Bhojpuri
    - Facebook MMS - For Gujarati
"""

import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
from .model_loader import _ensure_models_available, get_model_path, list_available_models

logger = logging.getLogger(__name__)


@dataclass
class TTSOutput:
    """Output from TTS synthesis."""

    audio: np.ndarray  # synthesized waveform (after style post-processing)
    sample_rate: int  # sample rate of ``audio`` in Hz
    duration: float  # len(audio) / sample_rate, in seconds
    voice: str  # voice key used for synthesis
    text: str  # text actually synthesized (after optional normalization)
    style: Optional[str] = None  # style preset name requested, if any


class StyleProcessor:
    """
    Prosody/style control via audio post-processing.

    Supports pitch shifting, speed change, and energy modification. Pitch and
    speed use librosa when it is importable and fall back to scipy resampling
    otherwise.
    """

    @staticmethod
    def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
        """Shift pitch without changing duration.

        Args:
            audio: Waveform samples.
            sample_rate: Sample rate of ``audio`` in Hz.
            pitch_factor: Frequency ratio (converted to ``12 * log2(factor)``
                semitones); 1.0 returns ``audio`` unchanged.

        Returns:
            The pitch-shifted waveform.
        """
        if pitch_factor == 1.0:
            return audio

        try:
            import librosa

            semitones = 12 * np.log2(pitch_factor)
            return librosa.effects.pitch_shift(
                audio.astype(np.float32), sr=sample_rate, n_steps=semitones
            )
        except ImportError:
            from scipy import signal

            # NOTE(review): this fallback resamples to len/pitch_factor and then
            # straight back to len(audio); the second step undoes the first, so
            # the result is band-limited but NOT pitch-shifted. A real fallback
            # needs a time-stretch (e.g. phase vocoder) before resampling.
            stretched = signal.resample(audio, int(len(audio) / pitch_factor))
            return signal.resample(stretched, len(audio))

    @staticmethod
    def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
        """Change speed/tempo without changing pitch.

        Args:
            audio: Waveform samples.
            sample_rate: Sample rate of ``audio`` in Hz (unused by the librosa path).
            speed_factor: Tempo ratio; >1.0 is faster, 1.0 returns ``audio`` unchanged.

        Returns:
            The time-stretched waveform.
        """
        if speed_factor == 1.0:
            return audio

        try:
            import librosa

            return librosa.effects.time_stretch(audio.astype(np.float32), rate=speed_factor)
        except ImportError:
            from scipy import signal

            # NOTE(review): plain resampling shortens/lengthens the signal but
            # also shifts its pitch by the same factor, unlike the librosa path.
            return signal.resample(audio, int(len(audio) / speed_factor))

    @staticmethod
    def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
        """Scale audio energy/volume by ``energy_factor``.

        When amplifying (factor > 1.0) and the scaled peak exceeds 0.95, the
        signal is soft-clipped with ``tanh(x * 2) * 0.95`` to bound its peak.
        """
        if energy_factor == 1.0:
            return audio

        modified = audio * energy_factor
        # Guard the size: np.max raises on a zero-length array.
        if energy_factor > 1.0 and modified.size > 0:
            if np.max(np.abs(modified)) > 0.95:
                modified = np.tanh(modified * 2) * 0.95
        return modified

    @staticmethod
    def apply_style(
        audio: np.ndarray,
        sample_rate: int,
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
    ) -> np.ndarray:
        """Apply pitch, then speed, then energy modifications (each skipped at 1.0)."""
        result = audio
        if pitch != 1.0:
            result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
        if speed != 1.0:
            result = StyleProcessor.apply_speed_change(result, sample_rate, speed)
        if energy != 1.0:
            result = StyleProcessor.apply_energy_change(result, energy)
        return result

    @staticmethod
    def get_preset(preset_name: str) -> Dict[str, float]:
        """Return the style parameters for ``preset_name``, or the "default" preset if unknown."""
        return STYLE_PRESETS.get(preset_name, STYLE_PRESETS["default"])


class TTSEngine:
    """
    Multi-lingual TTS Engine using trained VITS models.

    Supports 11 languages (10 Indian languages plus English) with male/female
    voices. JIT (.pt) and Coqui (.pth) voices are loaded from per-voice
    subdirectories of ``models_dir``. MMS voices are loaded through
    ``transformers`` using the voice's ``hf_model_id``.
    """

    def __init__(
        self,
        models_dir: str = MODELS_DIR,
        device: str = "auto",
        preload_voices: Optional[List[str]] = None,
    ):
        """
        Initialize TTS Engine

        Args:
            models_dir: Directory containing trained models (one subdirectory per voice key)
            device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
            preload_voices: List of voice keys to preload into memory
        """
        self.models_dir = Path(models_dir)
        self.device = self._get_device(device)

        # Ensure models are available
        _ensure_models_available()

        # Model caches, keyed by voice key
        self._models: Dict[str, torch.jit.ScriptModule] = {}
        self._tokenizers: Dict[str, TTSTokenizer] = {}
        self._coqui_models: Dict[str, Any] = {}
        self._mms_models: Dict[str, Any] = {}
        self._mms_tokenizers: Dict[str, Any] = {}

        self.normalizer = TextNormalizer()
        self.style_processor = StyleProcessor()

        if preload_voices:
            for voice in preload_voices:
                self.load_voice(voice)

        logger.info(f"TTS Engine initialized on device: {self.device}")

    def _get_device(self, device: str) -> torch.device:
        """Resolve ``device``; 'auto' picks CUDA when available, otherwise CPU."""
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    def load_voice(self, voice_key: str) -> bool:
        """
        Load a voice model into memory (no-op if it is already loaded).

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')

        Returns:
            True if loaded successfully

        Raises:
            ValueError: If ``voice_key`` is not in LANGUAGE_CONFIGS.
            FileNotFoundError: If a local voice has no model directory or model file.
        """
        if (
            voice_key in self._models
            or voice_key in self._coqui_models
            or voice_key in self._mms_models
        ):
            return True

        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice_key}")

        # MMS voices are loaded by hf_model_id, not from models_dir, so they
        # must bypass the local directory lookup below.
        if "mms" in voice_key:
            return self._load_mms_voice(voice_key)

        model_dir = self.models_dir / voice_key
        if not model_dir.exists():
            raise FileNotFoundError(f"Model not found: {model_dir}")

        # A .pth checkpoint means a Coqui model; otherwise expect a JIT .pt file.
        # Sorted so the chosen file does not depend on filesystem order.
        pth_files = sorted(model_dir.glob("*.pth"))
        if pth_files:
            return self._load_coqui_voice(voice_key, model_dir, pth_files[0])

        pt_files = sorted(model_dir.glob("*.pt"))
        if pt_files:
            return self._load_jit_voice(voice_key, model_dir, pt_files[0])

        raise FileNotFoundError(f"No model file found in {model_dir}")

    def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
        """Load a JIT traced VITS model and its character tokenizer.

        The tokenizer comes from ``chars.txt`` or, failing that, the first
        ``*chars*.txt`` file in ``model_dir``.
        """
        chars_path = model_dir / "chars.txt"
        if not chars_path.exists():
            candidates = sorted(model_dir.glob("*chars*.txt"))
            if not candidates:
                raise FileNotFoundError(f"No chars.txt found in {model_dir}")
            chars_path = candidates[0]
        tokenizer = TTSTokenizer.from_chars_file(str(chars_path))

        logger.info(f"Loading model from {model_path}")
        model = torch.jit.load(str(model_path), map_location=self.device)
        model.eval()

        self._models[voice_key] = model
        self._tokenizers[voice_key] = tokenizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
        """Load a Coqui TTS checkpoint (requires ``config.json`` next to it)."""
        config_path = model_dir / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"No config.json found in {model_dir}")

        # Only the import is guarded: an ImportError raised while building the
        # Synthesizer is a different problem and must not be masked.
        try:
            from TTS.utils.synthesizer import Synthesizer
        except ImportError as err:
            raise ImportError("Coqui TTS library not installed.") from err

        logger.info(f"Loading checkpoint from {checkpoint_path}")
        synthesizer = Synthesizer(
            tts_checkpoint=str(checkpoint_path),
            tts_config_path=str(config_path),
            use_cuda=self.device.type == "cuda",
        )
        self._coqui_models[voice_key] = synthesizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize with a Coqui model; returns (float32 waveform, sample rate)."""
        if voice_key not in self._coqui_models:
            self.load_voice(voice_key)

        synthesizer = self._coqui_models[voice_key]
        wav = synthesizer.tts(text)
        return np.array(wav, dtype=np.float32), synthesizer.output_sample_rate

    def _load_mms_voice(self, voice_key: str) -> bool:
        """Load a Facebook MMS VITS model and tokenizer via ``transformers``."""
        if voice_key in self._mms_models:
            return True

        config = LANGUAGE_CONFIGS[voice_key]
        logger.info(f"Loading MMS model: {config.hf_model_id}")

        try:
            from transformers import VitsModel, AutoTokenizer

            model = VitsModel.from_pretrained(config.hf_model_id)
            tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
            model = model.to(self.device)
            model.eval()
        except Exception as e:
            # Log for visibility, then let the caller handle the original error.
            logger.error(f"Failed to load MMS model: {e}")
            raise

        self._mms_models[voice_key] = model
        self._mms_tokenizers[voice_key] = tokenizer
        logger.info(f"Loaded MMS voice: {voice_key}")
        return True

    def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize with an MMS model; sample rate is taken from LANGUAGE_CONFIGS."""
        if voice_key not in self._mms_models:
            self._load_mms_voice(voice_key)

        model = self._mms_models[voice_key]
        tokenizer = self._mms_tokenizers[voice_key]
        config = LANGUAGE_CONFIGS[voice_key]

        inputs = tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs)

        audio = output.waveform.squeeze().cpu().numpy()
        return audio, config.sample_rate

    def _synthesize_jit(self, text: str, voice_key: str) -> np.ndarray:
        """Run an already-loaded JIT VITS model on ``text`` and return its waveform."""
        model = self._models[voice_key]
        tokenizer = self._tokenizers[voice_key]

        token_ids = tokenizer.text_to_ids(text)
        if len(token_ids) == 0:
            raise ValueError(f"No synthesizable characters in text for voice {voice_key!r}")

        # Pin int64 ids on every platform: np.array() of Python ints is int32 on
        # Windows with NumPy < 2, while token-id inputs are conventionally int64.
        x = torch.tensor(token_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        with torch.no_grad():
            audio = model(x)
        return audio.squeeze().cpu().numpy()

    def unload_voice(self, voice_key: str):
        """Drop ``voice_key`` from every model cache (unknown keys are ignored)."""
        self._models.pop(voice_key, None)
        self._tokenizers.pop(voice_key, None)
        self._coqui_models.pop(voice_key, None)
        self._mms_models.pop(voice_key, None)
        self._mms_tokenizers.pop(voice_key, None)

        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        logger.info(f"Unloaded voice: {voice_key}")

    def synthesize(
        self,
        text: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> TTSOutput:
        """
        Synthesize speech from text

        Args:
            text: Input text to synthesize
            voice: Voice key (e.g., 'hi_male', 'bn_female')
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch multiplier (0.5-2.0)
            energy: Energy/volume multiplier (0.5-2.0)
            style: Style preset name (e.g., 'happy', 'sad'); its factors are
                multiplied into speed/pitch/energy. Unknown names are ignored.
            normalize_text: Whether to apply text normalization

        Returns:
            TTSOutput with audio array and metadata

        Raises:
            KeyError: If ``voice`` is not in LANGUAGE_CONFIGS.
            ValueError: If a JIT voice's tokenizer yields no tokens for ``text``.
        """
        if style and style in STYLE_PRESETS:
            preset = STYLE_PRESETS[style]
            speed *= preset.get("speed", 1.0)
            pitch *= preset.get("pitch", 1.0)
            energy *= preset.get("energy", 1.0)

        config = LANGUAGE_CONFIGS[voice]

        if normalize_text:
            text = self.normalizer.clean_text(text, config.code)

        # Route to the backend that owns this voice.
        if "mms" in voice:
            audio_np, sample_rate = self._synthesize_mms(text, voice)
        else:
            self.load_voice(voice)  # no-op when already cached
            if voice in self._coqui_models:
                audio_np, sample_rate = self._synthesize_coqui(text, voice)
            else:
                audio_np = self._synthesize_jit(text, voice)
                sample_rate = config.sample_rate

        audio_np = self.style_processor.apply_style(
            audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
        )

        return TTSOutput(
            audio=audio_np,
            sample_rate=sample_rate,
            duration=len(audio_np) / sample_rate,
            voice=voice,
            text=text,
            style=style,
        )

    def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> str:
        """Synthesize speech, write it to ``output_path`` with soundfile, and return the path."""
        import soundfile as sf

        output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
        sf.write(output_path, output.audio, output.sample_rate)
        logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
        return output_path

    def get_loaded_voices(self) -> List[str]:
        """Return the keys of all voices currently held in any model cache."""
        return (
            list(self._models.keys())
            + list(self._coqui_models.keys())
            + list(self._mms_models.keys())
        )

    def get_available_voices(self) -> Dict[str, Dict]:
        """Describe every configured voice: name, code, gender, load/download status, model type."""
        voices = {}
        for key, config in LANGUAGE_CONFIGS.items():
            is_mms = "mms" in key
            model_dir = self.models_dir / key

            if is_mms:
                model_type = "mms"
            elif model_dir.exists() and any(model_dir.glob("*.pth")):
                model_type = "coqui"
            else:
                model_type = "vits"

            # "female" must be tested first: "male" is a substring of "female".
            if "female" in key:
                gender = "female"
            elif "male" in key:
                gender = "male"
            else:
                gender = "neutral"

            voices[key] = {
                "name": config.name,
                "code": config.code,
                "gender": gender,
                "loaded": key in self._models or key in self._coqui_models or key in self._mms_models,
                "downloaded": is_mms or get_model_path(key) is not None,
                "type": model_type,
            }
        return voices

    def get_style_presets(self) -> Dict[str, Dict]:
        """Return the STYLE_PRESETS mapping."""
        return STYLE_PRESETS

    def batch_synthesize(
        self,
        texts: List[str],
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
    ) -> List[TTSOutput]:
        """Synthesize each text in ``texts`` with the same voice and style settings."""
        return [self.synthesize(text, voice, speed, pitch, energy, style) for text in texts]


def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
    """Quick one-shot synthesis.

    Builds a fresh TTSEngine on every call (so models are loaded each time).
    Returns the written path when ``output_path`` is given, else a TTSOutput.
    """
    engine = TTSEngine()
    if output_path:
        return engine.synthesize_to_file(text, output_path, voice)
    return engine.synthesize(text, voice)