""" TTS Model Module ================ Handles model loading, inference optimization, and audio generation. Implements caching, mixed precision, and efficient batch processing. """ import os import logging import time from typing import Dict, List, Tuple, Optional, Union from pathlib import Path import torch import numpy as np from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan # Configure logging logger = logging.getLogger(__name__) class OptimizedTTSModel: """Optimized TTS model with caching and performance enhancements.""" def __init__(self, checkpoint: str = "Edmon02/TTS_NB_2", vocoder_checkpoint: str = "microsoft/speecht5_hifigan", device: Optional[str] = None, use_mixed_precision: bool = True, cache_embeddings: bool = True): """ Initialize the optimized TTS model. Args: checkpoint: Model checkpoint path vocoder_checkpoint: Vocoder checkpoint path device: Device to use ('cuda', 'cpu', or None for auto) use_mixed_precision: Whether to use mixed precision inference cache_embeddings: Whether to cache speaker embeddings """ self.checkpoint = checkpoint self.vocoder_checkpoint = vocoder_checkpoint self.use_mixed_precision = use_mixed_precision self.cache_embeddings = cache_embeddings # Auto-detect device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) logger.info(f"Using device: {self.device}") # Initialize components self.processor = None self.model = None self.vocoder = None self.speaker_embeddings = {} self.embedding_cache = {} # Performance tracking self.inference_times = [] # Load models self._load_models() self._load_speaker_embeddings() def _load_models(self): """Load TTS model, processor, and vocoder.""" try: logger.info("Loading TTS models...") start_time = time.time() # Load processor self.processor = SpeechT5Processor.from_pretrained(self.checkpoint) # Load main model self.model = SpeechT5ForTextToSpeech.from_pretrained(self.checkpoint) self.model.to(self.device) self.model.eval() # Set to evaluation mode # Load vocoder self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_checkpoint) self.vocoder.to(self.device) self.vocoder.eval() # Enable mixed precision if supported if self.use_mixed_precision and self.device.type == "cuda": self.model.half() self.vocoder.half() logger.info("Mixed precision enabled") load_time = time.time() - start_time logger.info(f"Models loaded in {load_time:.2f}s") except Exception as e: logger.error(f"Failed to load models: {e}") raise def _load_speaker_embeddings(self): """Load speaker embeddings from .npy files.""" try: # Define available speaker embeddings embedding_files = { "BDL": "nb_620.npy", # Add more speakers as needed } base_path = Path(__file__).parent.parent for speaker, filename in embedding_files.items(): filepath = base_path / filename if filepath.exists(): embedding = np.load(filepath).astype(np.float32) self.speaker_embeddings[speaker] = torch.tensor(embedding).to(self.device) logger.info(f"Loaded embedding for speaker {speaker}") else: logger.warning(f"Speaker embedding file not found: {filepath}") if not self.speaker_embeddings: raise FileNotFoundError("No speaker embeddings found") except Exception as e: logger.error(f"Failed to load speaker embeddings: {e}") raise def _get_speaker_embedding(self, speaker: str) -> torch.Tensor: """ Get speaker embedding with caching. 
    def _get_speaker_embedding(self, speaker: str) -> torch.Tensor:
        """
        Get speaker embedding with caching.

        Args:
            speaker: Speaker identifier

        Returns:
            Speaker embedding tensor
        """
        # Extract speaker code (first 3 characters)
        speaker_code = speaker[:3].upper()

        if speaker_code not in self.speaker_embeddings:
            logger.warning(f"Speaker {speaker_code} not found, using default")
            speaker_code = list(self.speaker_embeddings.keys())[0]

        # Return cached embedding with batch dimension
        embedding = self.speaker_embeddings[speaker_code]
        return embedding.unsqueeze(0)  # Add batch dimension

    def _preprocess_text(self, text: str) -> Optional[torch.Tensor]:
        """
        Preprocess text for model input.

        Args:
            text: Input text

        Returns:
            Processed input tensor, or None for empty input
        """
        if not text.strip():
            return None

        # Process text
        inputs = self.processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)

        # Limit input length to the model's maximum
        max_length = getattr(self.model.config, 'max_text_positions', 600)
        input_ids = input_ids[..., :max_length]

        return input_ids

    @torch.no_grad()
    def generate_speech(self, text: str, speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from text.

        Args:
            text: Input text
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        start_time = time.time()

        try:
            # Handle empty text
            if not text or not text.strip():
                logger.warning("Empty text provided")
                return 16000, np.zeros(0, dtype=np.int16)

            # Preprocess text
            input_ids = self._preprocess_text(text)
            if input_ids is None:
                return 16000, np.zeros(0, dtype=np.int16)

            # Get speaker embedding; match the model's dtype so the
            # half-precision path doesn't hit a float16/float32 mismatch
            speaker_embedding = self._get_speaker_embedding(speaker).to(self.model.dtype)

            # Generate speech with mixed precision if enabled
            if self.use_mixed_precision and self.device.type == "cuda":
                with torch.autocast(device_type="cuda"):
                    speech = self.model.generate_speech(
                        input_ids,
                        speaker_embedding,
                        vocoder=self.vocoder
                    )
            else:
                speech = self.model.generate_speech(
                    input_ids,
                    speaker_embedding,
                    vocoder=self.vocoder
                )

            # Convert to numpy, clip to [-1, 1] to avoid int16 overflow, and scale
            speech_np = speech.cpu().float().numpy()
            speech_int16 = (np.clip(speech_np, -1.0, 1.0) * 32767).astype(np.int16)

            # Track performance
            inference_time = time.time() - start_time
            self.inference_times.append(inference_time)

            logger.info(f"Generated {len(speech_int16)} samples in {inference_time:.3f}s")
            return 16000, speech_int16

        except Exception as e:
            logger.error(f"Speech generation failed: {e}")
            return 16000, np.zeros(0, dtype=np.int16)

    def generate_speech_chunks(self,
                               text_chunks: List[str],
                               speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from multiple text chunks and concatenate.

        Args:
            text_chunks: List of text chunks
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, concatenated_audio_array)
        """
        if not text_chunks:
            return 16000, np.zeros(0, dtype=np.int16)

        logger.info(f"Generating speech for {len(text_chunks)} chunks")

        audio_segments = []
        total_start_time = time.time()

        for i, chunk in enumerate(text_chunks):
            logger.debug(f"Processing chunk {i+1}/{len(text_chunks)}")
            sample_rate, audio = self.generate_speech(chunk, speaker)

            if len(audio) > 0:
                audio_segments.append(audio)

        if not audio_segments:
            logger.warning("No audio generated from chunks")
            return 16000, np.zeros(0, dtype=np.int16)

        # Concatenate all audio segments
        concatenated_audio = np.concatenate(audio_segments)

        total_time = time.time() - total_start_time
        logger.info(
            f"Generated {len(concatenated_audio)} samples from "
            f"{len(text_chunks)} chunks in {total_time:.3f}s"
        )

        return 16000, concatenated_audio
    def batch_generate_speech(self,
                              texts: List[str],
                              speaker: str = "BDL") -> List[Tuple[int, np.ndarray]]:
        """
        Generate speech for multiple texts (batch processing).

        Args:
            texts: List of input texts
            speaker: Speaker identifier

        Returns:
            List of (sample_rate, audio_array) tuples
        """
        results = []
        for text in texts:
            result = self.generate_speech(text, speaker)
            results.append(result)
        return results

    def get_performance_stats(self) -> Dict[str, float]:
        """Get performance statistics."""
        if not self.inference_times:
            return {"avg_inference_time": 0.0, "total_inferences": 0}

        return {
            "avg_inference_time": float(np.mean(self.inference_times)),
            "min_inference_time": float(np.min(self.inference_times)),
            "max_inference_time": float(np.max(self.inference_times)),
            "total_inferences": len(self.inference_times)
        }

    def clear_performance_cache(self):
        """Clear performance tracking data."""
        self.inference_times.clear()
        logger.info("Performance cache cleared")

    def get_available_speakers(self) -> List[str]:
        """Get list of available speakers."""
        return list(self.speaker_embeddings.keys())

    def optimize_for_inference(self):
        """Apply additional optimizations for inference."""
        try:
            if hasattr(torch.backends, 'cudnn'):
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = False

            # Compile model for better performance (PyTorch 2.0+)
            if hasattr(torch, 'compile') and self.device.type == "cuda":
                logger.info("Compiling model for optimization...")
                self.model = torch.compile(self.model)
                self.vocoder = torch.compile(self.vocoder)

            logger.info("Model optimization completed")

        except Exception as e:
            logger.warning(f"Model optimization failed: {e}")

    def warmup(self, warmup_text: str = "Բարև ձեզ"):
        """
        Warm up the model with a simple inference.

        Args:
            warmup_text: Text to use for warmup ("Hello" in Armenian)
        """
        logger.info("Warming up model...")
        try:
            _ = self.generate_speech(warmup_text)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")
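

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): this assumes the
# "Edmon02/TTS_NB_2" checkpoint is downloadable and that the speaker embedding
# file "nb_620.npy" exists one directory above this module, as
# _load_speaker_embeddings expects. Writing the WAV via scipy is an assumption;
# any writer that accepts a 16 kHz int16 array works.
if __name__ == "__main__":
    import scipy.io.wavfile

    logging.basicConfig(level=logging.INFO)

    model = OptimizedTTSModel()
    model.optimize_for_inference()
    model.warmup()

    sample_rate, audio = model.generate_speech("Բարև ձեզ", speaker="BDL")
    if len(audio) > 0:
        scipy.io.wavfile.write("output.wav", sample_rate, audio)
        logger.info(f"Wrote output.wav ({len(audio)} samples)")
    logger.info(f"Performance stats: {model.get_performance_stats()}")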