# flare / stt_factory.py
# Source page metadata: last commit "Update stt_factory.py" by ciyidogan
# (292760f, verified); file size 3.47 kB.
"""
STT Provider Factory for Flare
"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from logger import log_info, log_error, log_warning, log_debug
from stt_google import GoogleCloudSTT
from config_provider import ConfigProvider
from stt_interface import STTInterface
# Import providers conditionally: each optional backend is registered in
# ``stt_providers`` under its configuration name only if its module imports
# cleanly. A missing backend is a recoverable condition (the factory falls
# back to ``NoSTT``), so it is reported at warning level for every provider.
# NOTE(review): the unconditional ``from stt_google import GoogleCloudSTT``
# near the top of this file raises before this guard runs when the Google
# module is missing; consider removing that line.
stt_providers = {}  # config name -> provider class

try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")
class NoSTT(STTInterface):
    """Placeholder STT provider used when speech-to-text is disabled.

    Every operation is a no-op: starting and stopping a stream do nothing,
    streamed audio produces no results, and no languages are advertised.
    """

    async def start_streaming(self, config) -> None:
        """Accept the streaming ``config`` and ignore it."""
        return None

    async def stream_audio(self, audio_chunk: bytes):
        """Discard ``audio_chunk``; this async generator yields nothing."""
        # The unreachable ``yield`` after ``return`` is what makes this an
        # async generator, so callers can still ``async for`` over it.
        return
        yield

    async def stop_streaming(self):
        """Stop the (non-existent) stream; there is never a final result."""
        pass

    def supports_realtime(self) -> bool:
        """Report that this placeholder offers no realtime transcription."""
        return False

    def get_supported_languages(self):
        """Return an empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers from the global configuration."""

    @staticmethod
    def _parse_azure_key(api_key):
        """Split an Azure credential string of the form ``subscription_key|region``.

        Returns a ``(subscription_key, region)`` tuple with surrounding
        whitespace removed from each part, or ``None`` when the string does
        not contain exactly one ``|`` or either part is empty.
        """
        parts = [part.strip() for part in api_key.split('|')]
        if len(parts) != 2 or not all(parts):
            return None
        return parts[0], parts[1]

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``global_config.stt_provider`` (its ``name`` and ``api_key``)
        from ``ConfigProvider.get()`` and instantiates the matching class
        registered in ``stt_providers``.

        Returns:
            A provider instance, or a ``NoSTT`` placeholder when STT is
            disabled (``"no_stt"``), the provider is not installed, its
            credentials are missing or malformed, or any step raises.
            Every path returns an object; ``None`` is never returned.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name
            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only providers whose modules imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Get API key or credentials
            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON file.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # For Azure, api_key packs "subscription_key|region".
                azure_credentials = STTFactory._parse_azure_key(api_key)
                if azure_credentials is None:
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = azure_credentials
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered provider take a plain API key.
            return provider_class(api_key=api_key)
        except Exception as e:
            # Top-level boundary: any configuration or construction failure
            # is logged and degrades to the no-op provider instead of crashing.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers() -> list:
        """Return the names of registered STT providers plus ``"no_stt"``."""
        return list(stt_providers.keys()) + ["no_stt"]