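# NOTE: residue from the page this file was copied from, kept for reference:
# "Spaces: Runtime error" (banner shown twice).
"""
Configuration Module for TTS Pipeline
=====================================
Centralized configuration management for all TTS pipeline components:
per-environment profiles (development, staging, production) plus
environment-variable overrides.
"""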
import os | |
from dataclasses import dataclass | |
from typing import Optional, Dict, Any | |
import torch | |
@dataclass
class TextProcessingConfig:
    """Configuration for text processing components.

    Declared as a dataclass so it can be built with keyword arguments
    (``TextProcessingConfig(max_chunk_length=100)``), as ``ConfigManager``
    does; unspecified fields keep the defaults below.
    """

    # Upper bound on a single text chunk (unit not shown here — presumably
    # characters; TODO confirm against the chunker).
    max_chunk_length: int = 200
    # Presumably the number of words repeated between consecutive chunks.
    overlap_words: int = 5
    # Timeout for translation requests (presumably seconds — confirm).
    translation_timeout: int = 10
    # Toggle and capacity for a text-processing cache.
    enable_caching: bool = True
    cache_size: int = 1000
@dataclass
class ModelConfig:
    """Configuration for TTS model components.

    Declared as a dataclass so ``ConfigManager`` can construct it with
    keyword arguments; unspecified fields keep the defaults below.
    """

    # Model and vocoder identifiers (these look like Hugging Face Hub repo
    # ids — confirm with the loader).
    checkpoint: str = "Edmon02/TTS_NB_2"
    vocoder_checkpoint: str = "microsoft/speecht5_hifigan"
    # Target device string ("cuda", "mps", "cpu", ...); None means "not set".
    # How downstream code treats None is not shown in this module.
    device: Optional[str] = None
    use_mixed_precision: bool = True
    cache_embeddings: bool = True
    # Presumably the maximum input length accepted by the model — confirm.
    max_text_positions: int = 600
@dataclass
class AudioProcessingConfig:
    """Configuration for audio processing components.

    Declared as a dataclass so ``ConfigManager`` can construct it with
    keyword arguments; unspecified fields keep the defaults below.
    """

    # Length of the crossfade between audio segments (presumably seconds).
    crossfade_duration: float = 0.1
    # Output sample rate (presumably Hz).
    sample_rate: int = 16000
    apply_noise_gate: bool = True
    normalize_audio: bool = True
    # Noise-gate threshold in dB.
    noise_gate_threshold_db: float = -40.0
    # Peak level targeted by normalization (presumably linear amplitude).
    target_peak: float = 0.95
@dataclass
class PipelineConfig:
    """Main pipeline configuration.

    Declared as a dataclass so ``ConfigManager`` can construct it with
    keyword arguments; unspecified fields keep the defaults below.
    """

    enable_chunking: bool = True
    apply_audio_processing: bool = True
    enable_performance_tracking: bool = True
    # Concurrency cap; lowered by the development and staging profiles.
    max_concurrent_requests: int = 5
    warmup_on_init: bool = True
@dataclass
class DeploymentConfig:
    """Deployment-specific configuration.

    Declared as a dataclass so ``ConfigManager`` can construct it with
    keyword arguments; unspecified fields keep the defaults below.
    """

    environment: str = "production"  # development, staging, production
    log_level: str = "INFO"
    enable_health_checks: bool = True
    max_memory_mb: int = 2000
    # Not referenced elsewhere in this module; presumably a per-process cap
    # on GPU memory use — confirm where it is consumed.
    gpu_memory_fraction: float = 0.8
class ConfigManager:
    """Centralized configuration manager.

    On construction, selects a profile for ``environment`` and populates five
    component configs as attributes: ``text_processing``, ``model``,
    ``audio_processing``, ``pipeline`` and ``deployment``. Call
    :meth:`update_from_env` afterwards to layer ``TTS_*`` environment-variable
    overrides on top of the profile.
    """

    # Strings (case-insensitive, surrounding whitespace ignored) that switch a
    # boolean environment override on; any other non-empty value switches it off.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})

    def __init__(self, environment: str = "production"):
        """Build the configuration profile for ``environment``.

        Args:
            environment: ``"development"``, ``"staging"`` or ``"production"``.
                Any other value falls back to the production profile
                (``self.environment`` still keeps the string as given).
        """
        self.environment = environment
        self._load_environment_config()

    def _load_environment_config(self) -> None:
        """Dispatch to the profile loader matching ``self.environment``."""
        if self.environment == "development":
            self._load_dev_config()
        elif self.environment == "staging":
            self._load_staging_config()
        else:
            # "production" and any unrecognized name use the production profile.
            self._load_production_config()

    def _load_production_config(self) -> None:
        """Production profile: auto-detected device and full-size limits."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=200,
            overlap_words=5,
            translation_timeout=10,
            enable_caching=True,
            cache_size=1000,
        )
        self.model = ModelConfig(
            device=self._auto_detect_device(),
            # Mixed precision is enabled only when CUDA is available.
            use_mixed_precision=torch.cuda.is_available(),
            cache_embeddings=True,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.1,
            apply_noise_gate=True,
            normalize_audio=True,
        )
        self.pipeline = PipelineConfig(
            enable_chunking=True,
            apply_audio_processing=True,
            enable_performance_tracking=True,
            max_concurrent_requests=5,
        )
        self.deployment = DeploymentConfig(
            environment="production",
            log_level="INFO",
            enable_health_checks=True,
            max_memory_mb=2000,
        )

    def _load_dev_config(self) -> None:
        """Development profile: CPU only, smaller limits, debug logging.

        Fields not passed below keep their class defaults.
        """
        self.text_processing = TextProcessingConfig(
            max_chunk_length=100,  # Smaller chunks for testing
            translation_timeout=5,  # Shorter timeout for dev
            cache_size=100,
        )
        self.model = ModelConfig(
            device="cpu",  # Force CPU for consistent dev testing
            use_mixed_precision=False,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.05,  # Shorter for faster testing
        )
        self.pipeline = PipelineConfig(
            max_concurrent_requests=2,  # Limited for dev
        )
        self.deployment = DeploymentConfig(
            environment="development",
            log_level="DEBUG",
            max_memory_mb=1000,
        )

    def _load_staging_config(self) -> None:
        """Staging profile: production settings with more logging and lower limits."""
        self._load_production_config()
        # The production loader stamps environment="production"; correct it so
        # deployment.environment agrees with self.environment.
        self.deployment.environment = "staging"
        self.deployment.log_level = "DEBUG"
        self.deployment.max_memory_mb = 1500
        self.pipeline.max_concurrent_requests = 3

    def _auto_detect_device(self) -> str:
        """Return "cuda", else "mps" (Apple Silicon), else "cpu"."""
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no torch.backends.mps attribute at all.
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def get_all_config(self) -> Dict[str, Any]:
        """Return a snapshot of every component config as plain dicts.

        Each section is a shallow copy of the component's attributes, so
        adding, removing or reassigning keys in the result does not change
        the live configuration.
        """
        sections = {
            "text_processing": self.text_processing,
            "model": self.model,
            "audio_processing": self.audio_processing,
            "pipeline": self.pipeline,
            "deployment": self.deployment,
        }
        return {name: dict(vars(section)) for name, section in sections.items()}

    @classmethod
    def _parse_bool(cls, raw: str) -> bool:
        """Interpret an environment-variable string as a boolean."""
        return raw.strip().lower() in cls._TRUE_STRINGS

    def update_from_env(self) -> None:
        """Update configuration from ``TTS_*`` environment variables.

        Variables that are unset or empty leave the current value untouched.
        Overrides are applied in table order.

        Raises:
            ValueError: if a numeric variable cannot be parsed. The message
                names the offending variable. Overrides earlier in the table
                have already been applied at that point.
        """
        # (environment variable, target section, attribute, parser)
        overrides = (
            # Text processing
            ("TTS_MAX_CHUNK_LENGTH", self.text_processing, "max_chunk_length", int),
            ("TTS_TRANSLATION_TIMEOUT", self.text_processing, "translation_timeout", int),
            # Model
            ("TTS_MODEL_CHECKPOINT", self.model, "checkpoint", str),
            ("TTS_DEVICE", self.model, "device", str),
            ("TTS_USE_MIXED_PRECISION", self.model, "use_mixed_precision", self._parse_bool),
            # Audio processing
            ("TTS_CROSSFADE_DURATION", self.audio_processing, "crossfade_duration", float),
            # Pipeline
            ("TTS_MAX_CONCURRENT", self.pipeline, "max_concurrent_requests", int),
            # Deployment
            ("TTS_LOG_LEVEL", self.deployment, "log_level", str),
            ("TTS_MAX_MEMORY_MB", self.deployment, "max_memory_mb", int),
        )
        for var_name, section, attr, parse in overrides:
            raw = os.getenv(var_name)
            if not raw:
                continue  # unset or empty: keep the profile value
            try:
                value = parse(raw)
            except ValueError as err:
                raise ValueError(
                    f"Invalid value for environment variable {var_name}: {raw!r}"
                ) from err
            setattr(section, attr, value)
# Global config instance. The profile can be chosen with the TTS_ENVIRONMENT
# variable ("development", "staging" or "production"); when it is unset or
# empty the production profile is used, matching the previous default.
config = ConfigManager(os.getenv("TTS_ENVIRONMENT") or "production")
# Environment variable overrides are layered on top of the chosen profile.
config.update_from_env()
def get_config() -> ConfigManager:
    """Return the shared, module-level configuration manager.

    The lookup happens at call time, so after ``update_config`` has rebound
    the global this returns the new instance.
    """
    current = config
    return current
def update_config(environment: str) -> ConfigManager:
    """Rebuild the global configuration for ``environment``.

    The new manager is fully built, including environment-variable overrides,
    *before* it replaces the module-level ``config``. If ``update_from_env``
    raises, the previously installed configuration stays in place instead of
    a partially initialized one.

    Code that did ``from <this module> import config`` keeps its old
    reference; call ``get_config()`` to always get the current instance.

    Args:
        environment: Profile name passed to ``ConfigManager``.

    Returns:
        The newly installed ``ConfigManager``.
    """
    global config
    new_config = ConfigManager(environment)
    new_config.update_from_env()
    config = new_config
    return config