""" | |
TTS Model Module | |
================ | |
Handles model loading, inference optimization, and audio generation. | |
Implements caching, mixed precision, and efficient batch processing. | |
""" | |

import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

# Configure logging
logger = logging.getLogger(__name__)


class OptimizedTTSModel:
    """Optimized TTS model with caching and performance enhancements."""

    def __init__(self,
                 checkpoint: str = "Edmon02/TTS_NB_2",
                 vocoder_checkpoint: str = "microsoft/speecht5_hifigan",
                 device: Optional[str] = None,
                 use_mixed_precision: bool = True,
                 cache_embeddings: bool = True):
        """
        Initialize the optimized TTS model.

        Args:
            checkpoint: Model checkpoint path
            vocoder_checkpoint: Vocoder checkpoint path
            device: Device to use ('cuda', 'cpu', or None for auto)
            use_mixed_precision: Whether to use mixed precision inference
            cache_embeddings: Whether to cache speaker embeddings
        """
        self.checkpoint = checkpoint
        self.vocoder_checkpoint = vocoder_checkpoint
        self.use_mixed_precision = use_mixed_precision
        self.cache_embeddings = cache_embeddings

        # Auto-detect device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        logger.info(f"Using device: {self.device}")

        # Initialize components
        self.processor = None
        self.model = None
        self.vocoder = None
        self.speaker_embeddings = {}
        self.embedding_cache = {}

        # Performance tracking
        self.inference_times = []

        # Load models
        self._load_models()
        self._load_speaker_embeddings()

    def _load_models(self):
        """Load TTS model, processor, and vocoder."""
        try:
            logger.info("Loading TTS models...")
            start_time = time.time()

            # Load processor
            self.processor = SpeechT5Processor.from_pretrained(self.checkpoint)

            # Load main model
            self.model = SpeechT5ForTextToSpeech.from_pretrained(self.checkpoint)
            self.model.to(self.device)
            self.model.eval()  # Set to evaluation mode

            # Load vocoder
            self.vocoder = SpeechT5HifiGan.from_pretrained(self.vocoder_checkpoint)
            self.vocoder.to(self.device)
            self.vocoder.eval()

            # Enable mixed precision if supported
            if self.use_mixed_precision and self.device.type == "cuda":
                self.model.half()
                self.vocoder.half()
                logger.info("Mixed precision enabled")

            load_time = time.time() - start_time
            logger.info(f"Models loaded in {load_time:.2f}s")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def _load_speaker_embeddings(self):
        """Load speaker embeddings from .npy files."""
        try:
            # Define available speaker embeddings
            embedding_files = {
                "BDL": "nb_620.npy",
                # Add more speakers as needed
            }

            base_path = Path(__file__).parent.parent
            for speaker, filename in embedding_files.items():
                filepath = base_path / filename
                if filepath.exists():
                    embedding = np.load(filepath).astype(np.float32)
                    self.speaker_embeddings[speaker] = torch.tensor(embedding).to(self.device)
                    logger.info(f"Loaded embedding for speaker {speaker}")
                else:
                    logger.warning(f"Speaker embedding file not found: {filepath}")

            if not self.speaker_embeddings:
                raise FileNotFoundError("No speaker embeddings found")
        except Exception as e:
            logger.error(f"Failed to load speaker embeddings: {e}")
            raise

    def _get_speaker_embedding(self, speaker: str) -> torch.Tensor:
        """
        Get speaker embedding with caching.

        Args:
            speaker: Speaker identifier

        Returns:
            Speaker embedding tensor
        """
        # Extract speaker code (first 3 characters)
        speaker_code = speaker[:3].upper()
        if speaker_code not in self.speaker_embeddings:
            logger.warning(f"Speaker {speaker_code} not found, using default")
            speaker_code = list(self.speaker_embeddings.keys())[0]

        # Return cached embedding with a batch dimension, cast to the model's
        # dtype (float16 when mixed precision has halved the weights)
        embedding = self.speaker_embeddings[speaker_code]
        return embedding.unsqueeze(0).to(dtype=self.model.dtype)

    def _preprocess_text(self, text: str) -> Optional[torch.Tensor]:
        """
        Preprocess text for model input.

        Args:
            text: Input text

        Returns:
            Processed input tensor, or None for empty input
        """
        if not text.strip():
            return None

        # Tokenize text
        inputs = self.processor(text=text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)

        # Truncate input to the model's maximum text length
        max_length = getattr(self.model.config, 'max_text_positions', 600)
        input_ids = input_ids[..., :max_length]
        return input_ids

    def generate_speech(self, text: str, speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from text.

        Args:
            text: Input text
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, audio_array)
        """
        start_time = time.time()
        try:
            # Handle empty text
            if not text or not text.strip():
                logger.warning("Empty text provided")
                return 16000, np.zeros(0, dtype=np.int16)

            # Preprocess text
            input_ids = self._preprocess_text(text)
            if input_ids is None:
                return 16000, np.zeros(0, dtype=np.int16)

            # Get speaker embedding
            speaker_embedding = self._get_speaker_embedding(speaker)

            # Generate speech with mixed precision if enabled
            if self.use_mixed_precision and self.device.type == "cuda":
                with torch.cuda.amp.autocast():
                    speech = self.model.generate_speech(
                        input_ids,
                        speaker_embedding,
                        vocoder=self.vocoder
                    )
            else:
                speech = self.model.generate_speech(
                    input_ids,
                    speaker_embedding,
                    vocoder=self.vocoder
                )

            # Convert to numpy, clip to [-1, 1] to avoid int16 wrap-around,
            # then scale to int16
            speech_np = speech.float().cpu().numpy()
            speech_int16 = (np.clip(speech_np, -1.0, 1.0) * 32767).astype(np.int16)

            # Track performance
            inference_time = time.time() - start_time
            self.inference_times.append(inference_time)
            logger.info(f"Generated {len(speech_int16)} samples in {inference_time:.3f}s")

            return 16000, speech_int16
        except Exception as e:
            logger.error(f"Speech generation failed: {e}")
            return 16000, np.zeros(0, dtype=np.int16)

    def generate_speech_chunks(self, text_chunks: List[str], speaker: str = "BDL") -> Tuple[int, np.ndarray]:
        """
        Generate speech from multiple text chunks and concatenate.

        Args:
            text_chunks: List of text chunks
            speaker: Speaker identifier

        Returns:
            Tuple of (sample_rate, concatenated_audio_array)
        """
        if not text_chunks:
            return 16000, np.zeros(0, dtype=np.int16)

        logger.info(f"Generating speech for {len(text_chunks)} chunks")
        audio_segments = []
        total_start_time = time.time()

        for i, chunk in enumerate(text_chunks):
            logger.debug(f"Processing chunk {i+1}/{len(text_chunks)}")
            sample_rate, audio = self.generate_speech(chunk, speaker)
            if len(audio) > 0:
                audio_segments.append(audio)

        if not audio_segments:
            logger.warning("No audio generated from chunks")
            return 16000, np.zeros(0, dtype=np.int16)

        # Concatenate all audio segments
        concatenated_audio = np.concatenate(audio_segments)
        total_time = time.time() - total_start_time
        logger.info(
            f"Generated {len(concatenated_audio)} samples from "
            f"{len(text_chunks)} chunks in {total_time:.3f}s"
        )
        return 16000, concatenated_audio

    def batch_generate_speech(self, texts: List[str], speaker: str = "BDL") -> List[Tuple[int, np.ndarray]]:
        """
        Generate speech for multiple texts, processed sequentially.

        Args:
            texts: List of input texts
            speaker: Speaker identifier

        Returns:
            List of (sample_rate, audio_array) tuples
        """
        return [self.generate_speech(text, speaker) for text in texts]

    def get_performance_stats(self) -> Dict[str, float]:
        """Get performance statistics."""
        if not self.inference_times:
            return {"avg_inference_time": 0.0, "total_inferences": 0}
        return {
            "avg_inference_time": float(np.mean(self.inference_times)),
            "min_inference_time": float(np.min(self.inference_times)),
            "max_inference_time": float(np.max(self.inference_times)),
            "total_inferences": len(self.inference_times)
        }

    def clear_performance_cache(self):
        """Clear performance tracking data."""
        self.inference_times.clear()
        logger.info("Performance cache cleared")

    def get_available_speakers(self) -> List[str]:
        """Get list of available speakers."""
        return list(self.speaker_embeddings.keys())

    def optimize_for_inference(self):
        """Apply additional optimizations for inference."""
        try:
            if hasattr(torch.backends, 'cudnn'):
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = False

            # Compile model for better performance (PyTorch 2.0+)
            if hasattr(torch, 'compile') and self.device.type == "cuda":
                logger.info("Compiling model for optimization...")
                self.model = torch.compile(self.model)
                self.vocoder = torch.compile(self.vocoder)

            logger.info("Model optimization completed")
        except Exception as e:
            logger.warning(f"Model optimization failed: {e}")

    def warmup(self, warmup_text: str = "Բարև ձեզ"):
        """
        Warm up the model with a simple inference.

        Args:
            warmup_text: Text to use for warmup (Armenian for "hello")
        """
        logger.info("Warming up model...")
        try:
            _ = self.generate_speech(warmup_text)
            logger.info("Model warmup completed")
        except Exception as e:
            logger.warning(f"Model warmup failed: {e}")