Spaces:
Building
Building
""" | |
STT Provider Factory for Flare | |
""" | |
from typing import Optional | |
from stt_interface import STTInterface, STTEngineType | |
from logger import log_info, log_error, log_warning, log_debug | |
from stt_google import GoogleCloudSTT | |
from config_provider import ConfigProvider | |
from stt_interface import STTInterface | |
# Registry of STT provider classes keyed by engine name, matched against
# ``global_config.stt_provider.name`` in ``STTFactory.create_provider``.
# Each backend is imported inside its own try block so that a missing optional
# module only disables that engine instead of breaking this module. A missing
# backend is not an error condition, so it is reported with log_warning
# (consistent with the factory's own "not available" warning).
# NOTE(review): the unconditional ``from stt_google import GoogleCloudSTT`` in
# the import header makes this module fail to import when ``stt_google`` is
# unavailable, which defeats the Google guard below; consider removing it.
stt_providers = {}

try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")
class NoSTT(STTInterface):
    """Placeholder STT provider whose every operation is a no-op.

    Streaming start/stop do nothing, incoming audio is discarded without
    producing any results, realtime support is reported as unavailable,
    and no languages are advertised.
    """

    async def start_streaming(self, config) -> None:
        """Accept and ignore the streaming configuration."""

    async def stream_audio(self, audio_chunk: bytes):
        """Discard ``audio_chunk``; this async generator yields nothing."""
        return
        # Unreachable, but its presence turns the coroutine into an async
        # generator, so ``async for`` over it finishes immediately.
        yield

    async def stop_streaming(self):
        """Stop streaming; there is no pending result, so None is returned."""
        return None

    def supports_realtime(self) -> bool:
        """Report that realtime transcription is not supported."""
        return False

    def get_supported_languages(self):
        """Return a fresh empty list: no languages are supported."""
        return []

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers"""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``ConfigProvider.get().global_config.stt_provider``. Its ``name``
        selects a class from the ``stt_providers`` registry, and its ``api_key``
        is used as the credential:

        * ``"google"`` - ``api_key`` is passed as ``credentials_path``.
        * ``"azure"`` - ``api_key`` must have the form
          ``"subscription_key|region"`` (both parts non-empty).
        * any other registered engine - ``api_key`` is passed as ``api_key``.

        Returns:
            A provider instance. A ``NoSTT`` placeholder is returned instead when
            STT is disabled (``"no_stt"``), the engine is not registered, no key
            is configured, the Azure key is malformed, or any exception is raised
            while reading the configuration or constructing the provider.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name
            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON file.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # Surrounding whitespace is stripped and both halves must be
                # non-empty, so e.g. "key|" is rejected instead of producing an
                # empty region.
                parts = [part.strip() for part in api_key.split('|')]
                if len(parts) != 2 or not all(parts):
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered engine take a plain API key.
            return provider_class(api_key=api_key)
        except Exception as e:
            # Top-level boundary: provider construction must never crash the
            # caller, so log the failure and fall back to the no-op provider.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers() -> list:
        """Return registered STT engine names plus ``"no_stt"``, which is always available."""
        return list(stt_providers.keys()) + ["no_stt"]