Spaces:
Building
Building
File size: 3,958 Bytes
ff2d19f 82b3923 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
"""
STT Provider Factory for Flare
"""
from typing import Optional
from .stt_interface import STTInterface, STTEngineType
from utils.logger import log_info, log_error, log_warning, log_debug
from config.config_provider import ConfigProvider
# Import providers conditionally: each backend depends on its own optional SDK,
# so a missing package only removes that provider from the registry instead of
# breaking the whole module. All such cases are logged at the same (warning)
# level, with the ImportError reason, so a missing transitive dependency can be
# diagnosed from the log.
stt_providers = {}

try:
    from .stt_google import GoogleSTT
    stt_providers['google'] = GoogleSTT
except ImportError as e:
    log_warning(f"⚠️ Google Cloud STT not available: {e}")

try:
    from .stt_deepgram import DeepgramSTT
    stt_providers['deepgram'] = DeepgramSTT
except ImportError as e:
    log_warning(f"⚠️ Deepgram STT not available: {e}")

try:
    from .stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError as e:
    log_warning(f"⚠️ Azure STT not available: {e}")

try:
    from .stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError as e:
    log_warning(f"⚠️ Flicker STT not available: {e}")
class NoSTT(STTInterface):
    """Placeholder provider returned when speech-to-text is disabled or unusable.

    Every operation is a no-op: streaming produces no results, real-time
    recognition is not supported and no languages are advertised.
    """

    async def start_streaming(self, config) -> None:
        """Accept and ignore *config*; there is no session to open."""
        return None

    async def stream_audio(self, audio_chunk: bytes):
        """Discard *audio_chunk* and produce no transcription results.

        The ``yield`` inside the loop makes this an async generator, so callers
        can ``async for`` over it; iterating an empty tuple means it finishes
        immediately without producing anything.
        """
        for _nothing in ():
            yield _nothing

    async def stop_streaming(self):
        """There is nothing to shut down; the result is always ``None``."""
        pass

    def supports_realtime(self) -> bool:
        """Report that this provider never does real-time recognition."""
        return False

    def get_supported_languages(self):
        """Return a fresh empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Identifier matching the ``no_stt`` configuration value."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers"""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``ConfigProvider.get().global_config.stt_provider`` and uses its
        ``name``, ``api_key`` and ``requires_api_key`` attributes.

        Returns:
            An instance of the configured provider class. A ``NoSTT`` instance
            is returned instead when:
            - STT is disabled (``"no_stt"``);
            - the provider's module could not be imported;
            - a required API key is missing;
            - the key is malformed;
            - any exception occurs during creation. This case is logged.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name
            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only providers whose optional imports succeeded are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Get API key or credentials (decrypted if stored encrypted)
            api_key = STTFactory._get_api_key(stt_provider_config)
            if not api_key and stt_provider_config.requires_api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            return STTFactory._build_provider(stt_engine, provider_class, api_key)
        except Exception as e:
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def _build_provider(stt_engine: str, provider_class, api_key: Optional[str]):
        """Instantiate *provider_class* with the constructor arguments its engine expects.

        - ``google``: *api_key* is passed as ``credentials_path`` (the path to
          a credentials JSON file).
        - ``azure``: *api_key* must have the form ``subscription_key|region``.
          Surrounding whitespace is ignored, and both parts must be non-empty.
        - anything else: *api_key* is passed as ``api_key``.

        Returns ``NoSTT`` (with a warning) when the Azure key is malformed or
        absent.
        """
        if stt_engine == "google":
            return provider_class(credentials_path=api_key)

        if stt_engine == "azure":
            # Guard against a None key (possible when requires_api_key is false)
            # instead of crashing on None.split and hiding the real cause.
            parts = [part.strip() for part in (api_key or "").split('|')]
            if len(parts) != 2 or not all(parts):
                log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                return NoSTT()
            subscription_key, region = parts
            return provider_class(subscription_key=subscription_key, region=region)

        # deepgram, flicker and any other registered provider take a plain key.
        return provider_class(api_key=api_key)

    @staticmethod
    def get_available_providers():
        """Return the names of all importable STT providers plus ``"no_stt"``."""
        return [*stt_providers, "no_stt"]

    @staticmethod
    def _get_api_key(stt_config) -> Optional[str]:
        """Return the configured API key, decrypting it if it is stored encrypted.

        Returns ``None`` when no key is set. Values prefixed with ``"enc:"`` are
        passed through ``utils.encryption_utils.decrypt``. Any other value is
        returned unchanged.
        """
        if not stt_config.api_key:
            return None
        if stt_config.api_key.startswith("enc:"):
            # Imported lazily so the encryption helper is only loaded when needed.
            from utils.encryption_utils import decrypt
            return decrypt(stt_config.api_key)
        return stt_config.api_key