""" Main TTS Pipeline ================= Orchestrates the complete TTS pipeline with optimization and error handling. """ import logging import time from typing import Tuple, List, Optional, Dict, Any import numpy as np from .preprocessing import TextProcessor from .model import OptimizedTTSModel from .audio_processing import AudioProcessor # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class TTSPipeline: """ High-performance TTS pipeline with advanced optimization features. This pipeline combines: - Intelligent text preprocessing and chunking - Optimized model inference with caching - Advanced audio post-processing - Comprehensive error handling and logging """ def __init__(self, model_checkpoint: str = "Edmon02/TTS_NB_2", max_chunk_length: int = 200, crossfade_duration: float = 0.1, use_mixed_precision: bool = True, device: Optional[str] = None): """ Initialize the TTS pipeline. Args: model_checkpoint: Path to the TTS model checkpoint max_chunk_length: Maximum characters per text chunk crossfade_duration: Crossfade duration between audio chunks use_mixed_precision: Whether to use mixed precision inference device: Device to use for computation """ self.model_checkpoint = model_checkpoint self.max_chunk_length = max_chunk_length self.crossfade_duration = crossfade_duration logger.info("Initializing TTS Pipeline...") # Initialize components self.text_processor = TextProcessor(max_chunk_length=max_chunk_length) self.model = OptimizedTTSModel( checkpoint=model_checkpoint, use_mixed_precision=use_mixed_precision, device=device ) self.audio_processor = AudioProcessor(crossfade_duration=crossfade_duration) # Performance tracking self.total_inferences = 0 self.total_processing_time = 0.0 # Warm up the model self._warmup() logger.info("TTS Pipeline initialized successfully") def _warmup(self): """Warm up the pipeline with a test inference.""" try: logger.info("Warming up TTS pipeline...") test_text = "Բարև ձեզ" _ = self.synthesize(test_text, log_performance=False) logger.info("Pipeline warmup completed") except Exception as e: logger.warning(f"Pipeline warmup failed: {e}") def synthesize(self, text: str, speaker: str = "BDL", enable_chunking: bool = True, apply_audio_processing: bool = True, log_performance: bool = True) -> Tuple[int, np.ndarray]: """ Main synthesis function with automatic optimization. 
        Args:
            text: Input text to synthesize
            speaker: Speaker identifier
            enable_chunking: Whether to use intelligent chunking for long texts
            apply_audio_processing: Whether to apply audio post-processing
            log_performance: Whether to log performance metrics

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        start_time = time.time()

        try:
            # Validate input
            if not text or not text.strip():
                logger.warning("Empty or invalid text provided")
                return 16000, np.zeros(0, dtype=np.int16)

            # Determine if chunking is needed
            should_chunk = enable_chunking and len(text) > self.max_chunk_length

            if should_chunk:
                logger.info(f"Processing long text ({len(text)} chars) with chunking")
                sample_rate, audio = self._synthesize_with_chunking(
                    text, speaker, apply_audio_processing
                )
            else:
                logger.debug(f"Processing short text ({len(text)} chars) directly")
                sample_rate, audio = self._synthesize_direct(
                    text, speaker, apply_audio_processing
                )

            # Track performance
            total_time = time.time() - start_time
            self.total_inferences += 1
            self.total_processing_time += total_time

            if log_performance:
                audio_duration = len(audio) / sample_rate if len(audio) > 0 else 0
                rtf = total_time / audio_duration if audio_duration > 0 else float('inf')
                logger.info(
                    f"Synthesis completed: {len(text)} chars → "
                    f"{audio_duration:.2f}s audio in {total_time:.3f}s "
                    f"(RTF: {rtf:.2f})"
                )

            return sample_rate, audio

        except Exception as e:
            logger.error(f"Synthesis failed: {e}")
            return 16000, np.zeros(0, dtype=np.int16)

    def _synthesize_direct(self,
                           text: str,
                           speaker: str,
                           apply_audio_processing: bool) -> Tuple[int, np.ndarray]:
        """
        Direct synthesis for short texts.

        Args:
            text: Input text
            speaker: Speaker identifier
            apply_audio_processing: Whether to apply post-processing

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        # Process text
        processed_text = self.text_processor.process_text(text)

        # Generate speech
        sample_rate, audio = self.model.generate_speech(processed_text, speaker)

        # Apply audio processing if requested
        if apply_audio_processing and len(audio) > 0:
            audio = self.audio_processor.process_audio(audio)
            audio = self.audio_processor.add_silence(audio)

        return sample_rate, audio

    def _synthesize_with_chunking(self,
                                  text: str,
                                  speaker: str,
                                  apply_audio_processing: bool) -> Tuple[int, np.ndarray]:
        """
        Synthesis with intelligent chunking for long texts.

        Args:
            text: Input text
            speaker: Speaker identifier
            apply_audio_processing: Whether to apply post-processing

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        # Process and chunk text
        chunks = self.text_processor.process_chunks(text)

        if not chunks:
            logger.warning("No valid chunks generated")
            return 16000, np.zeros(0, dtype=np.int16)

        # Generate speech for all chunks
        sample_rate, audio = self.model.generate_speech_chunks(chunks, speaker)

        # Apply audio processing if requested
        if apply_audio_processing and len(audio) > 0:
            audio = self.audio_processor.process_audio(audio)
            audio = self.audio_processor.add_silence(audio)

        return sample_rate, audio

    def batch_synthesize(self,
                         texts: List[str],
                         speaker: str = "BDL",
                         enable_chunking: bool = True) -> List[Tuple[int, np.ndarray]]:
        """
        Batch synthesis for multiple texts.
        Args:
            texts: List of input texts
            speaker: Speaker identifier
            enable_chunking: Whether to use chunking

        Returns:
            List of (sample_rate, audio_array) tuples
        """
        logger.info(f"Starting batch synthesis for {len(texts)} texts")

        results = []
        for i, text in enumerate(texts):
            logger.debug(f"Processing batch item {i+1}/{len(texts)}")
            result = self.synthesize(
                text, speaker,
                enable_chunking=enable_chunking,
                log_performance=False
            )
            results.append(result)

        logger.info(f"Batch synthesis completed: {len(results)} items processed")
        return results

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get comprehensive performance statistics."""
        stats = {
            "pipeline_stats": {
                "total_inferences": self.total_inferences,
                "total_processing_time": self.total_processing_time,
                "avg_processing_time": (
                    self.total_processing_time / self.total_inferences
                    if self.total_inferences > 0 else 0
                )
            },
            "text_processor_stats": self.text_processor.get_cache_stats(),
            "model_stats": self.model.get_performance_stats(),
        }
        return stats

    def clear_caches(self):
        """Clear all caches to free memory."""
        self.text_processor.clear_cache()
        self.model.clear_performance_cache()
        logger.info("All caches cleared")

    def get_available_speakers(self) -> List[str]:
        """Get list of available speakers."""
        return self.model.get_available_speakers()

    def optimize_for_production(self):
        """Apply production-level optimizations."""
        logger.info("Applying production optimizations...")

        try:
            # Optimize model
            self.model.optimize_for_inference()

            # Clear any unnecessary caches
            self.clear_caches()

            logger.info("Production optimizations applied")
        except Exception as e:
            logger.warning(f"Some optimizations failed: {e}")

    def health_check(self) -> Dict[str, Any]:
        """
        Perform a health check of the pipeline.

        Returns:
            Health status information
        """
        health_status = {
            "status": "healthy",
            "components": {},
            "timestamp": time.time()
        }

        try:
            # Test text processor
            test_text = "Թեստ տեքստ"
            processed = self.text_processor.process_text(test_text)
            health_status["components"]["text_processor"] = {
                "status": "ok" if processed else "error",
                "test_result": bool(processed)
            }

            # Test model
            try:
                _, audio = self.model.generate_speech("Բարև")
                health_status["components"]["model"] = {
                    "status": "ok" if len(audio) > 0 else "error",
                    "test_audio_samples": len(audio)
                }
            except Exception as e:
                health_status["components"]["model"] = {
                    "status": "error",
                    "error": str(e)
                }

            # Check if any component failed
            if any(comp.get("status") == "error"
                   for comp in health_status["components"].values()):
                health_status["status"] = "degraded"

        except Exception as e:
            health_status["status"] = "error"
            health_status["error"] = str(e)

        return health_status
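

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module's API): shows how the
# pipeline above is typically driven end to end. Because this module uses
# relative imports, it assumes execution as part of the package (e.g. via
# `python -m <package>.pipeline`); the output path "output.wav" is a
# placeholder, and "BDL" is simply the default speaker defined above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import scipy.io.wavfile as wavfile

    pipeline = TTSPipeline()

    # Single synthesis; chunking kicks in automatically once the input
    # exceeds max_chunk_length characters.
    sample_rate, audio = pipeline.synthesize("Բարև ձեզ, ինչպե՞ս եք:")
    if len(audio) > 0:
        wavfile.write("output.wav", sample_rate, audio)

    # Verify component health and inspect accumulated performance counters.
    print(pipeline.health_check()["status"])
    print(pipeline.get_performance_stats()["pipeline_stats"])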