# SpeechT5_hy / src/config.py
# Author: Edmon02 — commit b163aa7:
# "Implement optimized TTS pipeline with advanced text preprocessing,
#  audio processing, and comprehensive error handling"
"""
Configuration Module for TTS Pipeline
=====================================
Centralized configuration management for all pipeline components.
"""
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any
import torch
@dataclass
class TextProcessingConfig:
    """Configuration for text processing components.

    Consumed by the text preprocessing stage (chunking, translation,
    caching) of the TTS pipeline.
    """
    max_chunk_length: int = 200  # max size per text chunk — presumably characters; confirm in the chunker
    overlap_words: int = 5  # words shared between consecutive chunks (smooths chunk boundaries)
    translation_timeout: int = 10  # seconds to wait for the translation service before giving up
    enable_caching: bool = True  # cache processed text to avoid repeated work
    cache_size: int = 1000  # max number of cached entries
@dataclass
class ModelConfig:
    """Configuration for TTS model components.

    Identifies the model/vocoder checkpoints and runtime execution options.
    """
    checkpoint: str = "Edmon02/TTS_NB_2"  # Hugging Face model id for the SpeechT5 TTS checkpoint
    vocoder_checkpoint: str = "microsoft/speecht5_hifigan"  # HiFi-GAN vocoder checkpoint
    device: Optional[str] = None  # "cuda" / "mps" / "cpu"; None means auto-detect downstream
    use_mixed_precision: bool = True  # fp16 inference where supported (only useful on CUDA)
    cache_embeddings: bool = True  # keep speaker embeddings in memory between calls
    max_text_positions: int = 600  # model's max text positions — presumably SpeechT5's input limit; confirm
@dataclass
class AudioProcessingConfig:
    """Configuration for audio processing components.

    Post-processing applied to synthesized audio (chunk joining, gating,
    normalization).
    """
    crossfade_duration: float = 0.1  # seconds of crossfade when joining audio chunks
    sample_rate: int = 16000  # output sample rate in Hz (SpeechT5's native rate)
    apply_noise_gate: bool = True  # suppress low-level background noise
    normalize_audio: bool = True  # scale output so its peak hits target_peak
    noise_gate_threshold_db: float = -40.0  # signal below this level (dBFS) is gated out
    target_peak: float = 0.95  # normalization target as a fraction of full scale
@dataclass
class PipelineConfig:
    """Main pipeline configuration.

    High-level switches controlling which pipeline stages run and how many
    requests are served at once.
    """
    enable_chunking: bool = True  # split long text into chunks before synthesis
    apply_audio_processing: bool = True  # run the AudioProcessingConfig post-processing stage
    enable_performance_tracking: bool = True  # collect timing/throughput metrics
    max_concurrent_requests: int = 5  # cap on simultaneous synthesis requests
    warmup_on_init: bool = True  # run a warm-up inference when the pipeline starts
@dataclass
class DeploymentConfig:
    """Deployment-specific configuration.

    Operational limits and observability settings for the serving
    environment.
    """
    environment: str = "production"  # development, staging, production
    log_level: str = "INFO"  # standard logging level name
    enable_health_checks: bool = True  # expose/serve health-check probes
    max_memory_mb: int = 2000  # soft memory budget in MB — enforcement happens elsewhere; confirm
    gpu_memory_fraction: float = 0.8  # fraction of GPU memory the process may claim
class ConfigManager:
    """Centralized configuration manager.

    Builds the per-environment configuration objects (text processing,
    model, audio processing, pipeline, deployment) and supports overriding
    selected values from environment variables via :meth:`update_from_env`.
    """

    def __init__(self, environment: str = "production"):
        """Load configuration for *environment*.

        Args:
            environment: One of ``"development"``, ``"staging"`` or
                ``"production"``. Any unrecognized value falls back to the
                production profile.
        """
        self.environment = environment
        self._load_environment_config()

    def _load_environment_config(self) -> None:
        """Dispatch to the loader matching the current environment."""
        if self.environment == "development":
            self._load_dev_config()
        elif self.environment == "staging":
            self._load_staging_config()
        else:
            # Unknown environments deliberately get the safest profile.
            self._load_production_config()

    def _load_production_config(self) -> None:
        """Production environment configuration."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=200,
            overlap_words=5,
            translation_timeout=10,
            enable_caching=True,
            cache_size=1000,
        )
        self.model = ModelConfig(
            device=self._auto_detect_device(),
            # Mixed precision only pays off when CUDA is present.
            use_mixed_precision=torch.cuda.is_available(),
            cache_embeddings=True,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.1,
            apply_noise_gate=True,
            normalize_audio=True,
        )
        self.pipeline = PipelineConfig(
            enable_chunking=True,
            apply_audio_processing=True,
            enable_performance_tracking=True,
            max_concurrent_requests=5,
        )
        self.deployment = DeploymentConfig(
            environment="production",
            log_level="INFO",
            enable_health_checks=True,
            max_memory_mb=2000,
        )

    def _load_dev_config(self) -> None:
        """Development environment configuration."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=100,  # Smaller chunks for testing
            translation_timeout=5,  # Shorter timeout for dev
            cache_size=100,
        )
        self.model = ModelConfig(
            device="cpu",  # Force CPU for consistent dev testing
            use_mixed_precision=False,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.05  # Shorter for faster testing
        )
        self.pipeline = PipelineConfig(
            max_concurrent_requests=2  # Limited for dev
        )
        self.deployment = DeploymentConfig(
            environment="development",
            log_level="DEBUG",
            max_memory_mb=1000,
        )

    def _load_staging_config(self) -> None:
        """Staging environment configuration.

        Starts from the production profile, then enables verbose logging
        and tightens resource limits.
        """
        self._load_production_config()
        self.deployment.log_level = "DEBUG"
        self.deployment.max_memory_mb = 1500
        self.pipeline.max_concurrent_requests = 3

    def _auto_detect_device(self) -> str:
        """Auto-detect the optimal device for deployment.

        Returns:
            ``"cuda"``, ``"mps"`` (Apple Silicon) or ``"cpu"``, in that
            order of preference.
        """
        if torch.cuda.is_available():
            return "cuda"
        # hasattr guard keeps this working on torch builds without MPS.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def get_all_config(self) -> Dict[str, Any]:
        """Return all configuration sections as a plain dictionary.

        Note: values are the live ``__dict__`` of each dataclass, so
        mutating the returned dicts mutates the configuration.
        """
        return {
            "text_processing": self.text_processing.__dict__,
            "model": self.model.__dict__,
            "audio_processing": self.audio_processing.__dict__,
            "pipeline": self.pipeline.__dict__,
            "deployment": self.deployment.__dict__,
        }

    def update_from_env(self) -> None:
        """Override selected settings from environment variables.

        Unset or empty variables are ignored (truthiness check, matching
        the original behavior). Numeric variables holding non-numeric
        text raise ``ValueError``.
        """
        # Text processing — walrus avoids the double os.getenv lookup.
        if value := os.getenv("TTS_MAX_CHUNK_LENGTH"):
            self.text_processing.max_chunk_length = int(value)
        if value := os.getenv("TTS_TRANSLATION_TIMEOUT"):
            self.text_processing.translation_timeout = int(value)
        # Model
        if value := os.getenv("TTS_MODEL_CHECKPOINT"):
            self.model.checkpoint = value
        if value := os.getenv("TTS_DEVICE"):
            self.model.device = value
        if value := os.getenv("TTS_USE_MIXED_PRECISION"):
            # Only the exact (case-insensitive) string "true" enables it.
            self.model.use_mixed_precision = value.lower() == "true"
        # Audio processing
        if value := os.getenv("TTS_CROSSFADE_DURATION"):
            self.audio_processing.crossfade_duration = float(value)
        # Pipeline
        if value := os.getenv("TTS_MAX_CONCURRENT"):
            self.pipeline.max_concurrent_requests = int(value)
        # Deployment
        if value := os.getenv("TTS_LOG_LEVEL"):
            self.deployment.log_level = value
        if value := os.getenv("TTS_MAX_MEMORY_MB"):
            self.deployment.max_memory_mb = int(value)
# Module-level singleton, built for the default ("production") environment
# at import time.
config = ConfigManager()
# Apply environment-variable overrides immediately so importers always see
# the final effective values.
config.update_from_env()
def get_config() -> ConfigManager:
    """Return the module-level :class:`ConfigManager` singleton.

    The instance is created (and env-overridden) at import time; this
    accessor simply exposes it.
    """
    return config
def update_config(environment: str):
    """Rebuild the global configuration for *environment* and return it.

    Replaces the module-level singleton with a fresh :class:`ConfigManager`
    and re-applies environment-variable overrides before returning it.
    """
    global config
    config = ConfigManager(environment)
    config.update_from_env()
    return config