Spaces:
Runtime error
Runtime error
File size: 7,110 Bytes
b163aa7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
"""
Configuration Module for TTS Pipeline
=====================================
Centralized configuration management for all pipeline components.
"""
import os
from dataclasses import asdict, dataclass
from typing import Optional, Dict, Any

import torch
@dataclass
class TextProcessingConfig:
    """Settings for the text-processing stage of the TTS pipeline.

    Every field carries a default, so ``TextProcessingConfig()`` is a
    complete configuration by itself; pass keyword arguments to override
    individual values.
    """

    # Upper bound on the size of one text chunk. The unit is not visible in
    # this module -- presumably characters; confirm against the chunker.
    max_chunk_length: int = 200

    # Going by the name, the number of words shared between neighbouring
    # chunks -- confirm against the chunker.
    overlap_words: int = 5

    # Timeout applied to translation (presumably seconds -- confirm).
    translation_timeout: int = 10

    # Caching switch and, presumably, the capacity of that cache.
    enable_caching: bool = True
    cache_size: int = 1000
@dataclass
class ModelConfig:
    """Settings for the TTS model and its vocoder.

    All fields have defaults, so ``ModelConfig()`` is usable without
    arguments.
    """

    # Identifiers of the acoustic model and the vocoder. They look like
    # Hugging Face Hub repository ids -- confirm against the model loader.
    checkpoint: str = "Edmon02/TTS_NB_2"
    vocoder_checkpoint: str = "microsoft/speecht5_hifigan"

    # Target device name. ``None`` is only the dataclass default; how a
    # consumer treats ``None`` is not visible in this module.
    device: Optional[str] = None

    use_mixed_precision: bool = True
    cache_embeddings: bool = True

    # Presumably the model's maximum input length in positions -- confirm.
    max_text_positions: int = 600
@dataclass
class AudioProcessingConfig:
    """Settings for the audio post-processing stage.

    All fields have defaults, so ``AudioProcessingConfig()`` is usable
    without arguments.
    """

    # Length of the crossfade (presumably seconds, applied where chunks are
    # joined -- confirm against the audio processor).
    crossfade_duration: float = 0.1

    # Presumably samples per second (Hz) -- confirm.
    sample_rate: int = 16000

    # Feature switches for the noise gate and for normalisation.
    apply_noise_gate: bool = True
    normalize_audio: bool = True

    # Noise-gate threshold; the ``_db`` suffix suggests decibels.
    noise_gate_threshold_db: float = -40.0

    # Presumably the peak amplitude targeted by normalisation -- confirm.
    target_peak: float = 0.95
@dataclass
class PipelineConfig:
    """Top-level switches and limits for the pipeline as a whole.

    All fields have defaults, so ``PipelineConfig()`` is usable without
    arguments.
    """

    # Feature switches; what each one gates is decided by the pipeline code,
    # which is not visible in this module.
    enable_chunking: bool = True
    apply_audio_processing: bool = True
    enable_performance_tracking: bool = True

    # Concurrency limit (enforcement happens outside this module).
    max_concurrent_requests: int = 5

    # Presumably triggers a warm-up run when the pipeline is constructed --
    # confirm against the pipeline initialiser.
    warmup_on_init: bool = True
@dataclass
class DeploymentConfig:
    """Settings that describe the deployment the service runs in.

    All fields have defaults, so ``DeploymentConfig()`` is usable without
    arguments.
    """

    # One of: development, staging, production.
    environment: str = "production"

    # Logging level name, stored as a plain string.
    log_level: str = "INFO"

    enable_health_checks: bool = True

    # Memory budget in megabytes (per the ``_mb`` suffix).
    max_memory_mb: int = 2000

    # Presumably the share of GPU memory the process may claim -- confirm.
    gpu_memory_fraction: float = 0.8
class ConfigManager:
    """Build and hold the configuration sections for one environment.

    After construction the instance exposes one attribute per section:
    ``text_processing``, ``model``, ``audio_processing``, ``pipeline`` and
    ``deployment``. Call :meth:`update_from_env` afterwards to layer
    ``TTS_*`` environment-variable overrides on top of the preset.
    """

    # Spellings (compared stripped and lower-cased) that switch a boolean
    # environment variable on. Every other non-empty value reads as False.
    _TRUE_STRINGS = frozenset({"1", "true", "yes", "on"})

    def __init__(self, environment: str = "production"):
        """Load the preset for *environment*.

        Args:
            environment: ``"development"``, ``"staging"`` or
                ``"production"``. Any other value loads the production
                preset, while ``self.environment`` keeps the string that
                was passed in.
        """
        self.environment = environment
        self._load_environment_config()

    def _load_environment_config(self):
        """Dispatch to the preset loader matching ``self.environment``."""
        loaders = {
            "development": self._load_dev_config,
            "staging": self._load_staging_config,
        }
        # Unknown names deliberately fall back to the production preset.
        loaders.get(self.environment, self._load_production_config)()

    def _load_production_config(self):
        """Populate every section with the production preset."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=200,
            overlap_words=5,
            translation_timeout=10,
            enable_caching=True,
            cache_size=1000,
        )
        self.model = ModelConfig(
            device=self._auto_detect_device(),
            # Mixed precision is only switched on when CUDA is present.
            use_mixed_precision=torch.cuda.is_available(),
            cache_embeddings=True,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.1,
            apply_noise_gate=True,
            normalize_audio=True,
        )
        self.pipeline = PipelineConfig(
            enable_chunking=True,
            apply_audio_processing=True,
            enable_performance_tracking=True,
            max_concurrent_requests=5,
        )
        self.deployment = DeploymentConfig(
            environment="production",
            log_level="INFO",
            enable_health_checks=True,
            max_memory_mb=2000,
        )

    def _load_dev_config(self):
        """Populate every section with the development preset."""
        self.text_processing = TextProcessingConfig(
            max_chunk_length=100,  # Smaller chunks for testing
            translation_timeout=5,  # Shorter timeout for dev
            cache_size=100,
        )
        self.model = ModelConfig(
            device="cpu",  # Force CPU for consistent dev testing
            use_mixed_precision=False,
        )
        self.audio_processing = AudioProcessingConfig(
            crossfade_duration=0.05,  # Shorter for faster testing
        )
        self.pipeline = PipelineConfig(
            max_concurrent_requests=2,  # Limited for dev
        )
        self.deployment = DeploymentConfig(
            environment="development",
            log_level="DEBUG",
            max_memory_mb=1000,
        )

    def _load_staging_config(self):
        """Populate every section with the staging preset.

        Staging is the production preset with more logging and smaller
        limits.
        """
        self._load_production_config()
        # The production loader stamps environment="production"; correct it
        # so the deployment section agrees with the environment that was
        # actually requested.
        self.deployment.environment = "staging"
        self.deployment.log_level = "DEBUG"
        self.deployment.max_memory_mb = 1500
        self.pipeline.max_concurrent_requests = 3

    def _auto_detect_device(self) -> str:
        """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, in that preference order."""
        if torch.cuda.is_available():
            return "cuda"
        # torch builds without the MPS backend have no ``torch.backends.mps``.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"  # Apple Silicon
        return "cpu"

    def get_all_config(self) -> Dict[str, Any]:
        """Return a snapshot of every section as plain dictionaries.

        Returns:
            A dict keyed by section name. Each value is an independent copy
            of that section's fields, so editing the snapshot never changes
            the live configuration.
        """
        return {
            "text_processing": asdict(self.text_processing),
            "model": asdict(self.model),
            "audio_processing": asdict(self.audio_processing),
            "pipeline": asdict(self.pipeline),
            "deployment": asdict(self.deployment),
        }

    @classmethod
    def _parse_bool(cls, raw: str) -> bool:
        """Interpret *raw* as a boolean flag (see ``_TRUE_STRINGS``)."""
        return raw.strip().lower() in cls._TRUE_STRINGS

    def update_from_env(self):
        """Apply ``TTS_*`` environment-variable overrides to the sections.

        A variable that is unset or empty leaves the current value alone.
        Overrides are applied in the order listed below, so values applied
        before a failing one stay in effect.

        Raises:
            ValueError: If a numeric variable cannot be parsed. The message
                names the offending variable.
        """
        overrides = (
            # Text processing
            ("TTS_MAX_CHUNK_LENGTH", self.text_processing, "max_chunk_length", int),
            ("TTS_TRANSLATION_TIMEOUT", self.text_processing, "translation_timeout", int),
            # Model
            ("TTS_MODEL_CHECKPOINT", self.model, "checkpoint", str),
            ("TTS_DEVICE", self.model, "device", str),
            ("TTS_USE_MIXED_PRECISION", self.model, "use_mixed_precision", self._parse_bool),
            # Audio processing
            ("TTS_CROSSFADE_DURATION", self.audio_processing, "crossfade_duration", float),
            # Pipeline
            ("TTS_MAX_CONCURRENT", self.pipeline, "max_concurrent_requests", int),
            # Deployment
            ("TTS_LOG_LEVEL", self.deployment, "log_level", str),
            ("TTS_MAX_MEMORY_MB", self.deployment, "max_memory_mb", int),
        )
        for var_name, section, field_name, convert in overrides:
            raw = os.environ.get(var_name)
            if not raw:
                continue  # unset or empty: keep the preset value
            try:
                value = convert(raw)
            except ValueError as err:
                raise ValueError(
                    f"Invalid value {raw!r} for environment variable {var_name}"
                ) from err
            setattr(section, field_name, value)
# Module-level configuration singleton. It is built at import time from the
# production preset, then adjusted in place by whichever TTS_* environment
# variables are set.
config = ConfigManager(environment="production")
config.update_from_env()
def get_config() -> ConfigManager:
    """Return the module-level :class:`ConfigManager` instance.

    The global is looked up on every call, so the result reflects the most
    recent rebinding of ``config`` rather than the instance that existed
    when the caller first imported this module.
    """
    return config
def update_config(environment: str) -> ConfigManager:
    """Replace the global configuration with one built for *environment*.

    Args:
        environment: Preset name passed straight to :class:`ConfigManager`.

    Returns:
        The newly installed global ``ConfigManager``.

    Raises:
        ValueError: Propagated from ``update_from_env`` when a numeric
            ``TTS_*`` variable cannot be parsed. In that case the previous
            global configuration is left in place.
    """
    global config
    # Build and fully configure the replacement before publishing it, so a
    # failure while applying environment overrides cannot leave the global
    # pointing at a half-configured instance.
    replacement = ConfigManager(environment)
    replacement.update_from_env()
    config = replacement
    return replacement
|