""" STT Provider Factory for Flare """ from typing import Optional from .stt_interface import STTInterface, STTEngineType from utils.logger import log_info, log_error, log_warning, log_debug from config.config_provider import ConfigProvider # Import providers conditionally stt_providers = {} try: from .stt_google import GoogleSTT stt_providers['google'] = GoogleSTT except ImportError: log_info("⚠️ Google Cloud STT not available") try: from .stt_deepgram import DeepgramSTT stt_providers['deepgram'] = DeepgramSTT except ImportError: log_info("⚠️ Deepgram STT not available") try: from .stt_azure import AzureSTT stt_providers['azure'] = AzureSTT except ImportError: log_error("⚠️ Azure STT not available") try: from .stt_flicker import FlickerSTT stt_providers['flicker'] = FlickerSTT except ImportError: log_error("⚠️ Flicker STT not available") class NoSTT(STTInterface): """Dummy STT provider when STT is disabled""" async def start_streaming(self, config) -> None: pass async def stream_audio(self, audio_chunk: bytes): return yield # Make it a generator async def stop_streaming(self): return None def supports_realtime(self) -> bool: return False def get_supported_languages(self): return [] def get_provider_name(self) -> str: return "no_stt" class STTFactory: """Factory for creating STT providers""" @staticmethod def create_provider() -> Optional[STTInterface]: """Create STT provider based on configuration""" try: cfg = ConfigProvider.get() stt_provider_config = cfg.global_config.stt_provider stt_engine = stt_provider_config.name log_info(f"🎤 Creating STT provider: {stt_engine}") if stt_engine == "no_stt": return NoSTT() # Get provider class provider_class = stt_providers.get(stt_engine) if not provider_class: log_warning(f"⚠️ STT provider '{stt_engine}' not available") return NoSTT() # Get API key or credentials api_key = STTFactory._get_api_key(stt_provider_config) if not api_key and stt_provider_config.requires_api_key: log_warning(f"⚠️ No API key configured for {stt_engine}") return NoSTT() # Create provider instance if stt_engine == "google": # For Google, api_key is the path to credentials JSON return provider_class(credentials_path=api_key) elif stt_engine == "deepgram": return provider_class(api_key=api_key) elif stt_engine == "azure": # For Azure, parse the key format parts = api_key.split('|') if len(parts) != 2: log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region") return NoSTT() return provider_class(subscription_key=parts[0], region=parts[1]) elif stt_engine == "flicker": return provider_class(api_key=api_key) else: return provider_class(api_key=api_key) except Exception as e: log_error("❌ Failed to create STT provider", e) return NoSTT() @staticmethod def get_available_providers(): """Get list of available STT providers""" return list(stt_providers.keys()) + ["no_stt"] @staticmethod def _get_api_key(stt_config) -> Optional[str]: """Get decrypted API key""" if not stt_config.api_key: return None if stt_config.api_key.startswith("enc:"): from utils.encryption_utils import decrypt return decrypt(stt_config.api_key) return stt_config.api_key