Spaces:
Building
Building
File size: 3,473 Bytes
a847f43 b236687 a847f43 b236687 98796a5 8eb3adf a847f43 b236687 321ebc6 a847f43 321ebc6 8eb3adf 321ebc6 8eb3adf 321ebc6 8eb3adf 321ebc6 a847f43 321ebc6 a847f43 b236687 321ebc6 292760f 321ebc6 8eb3adf 321ebc6 292760f 321ebc6 292760f 321ebc6 292760f 321ebc6 292760f 321ebc6 8eb3adf 321ebc6 b236687 321ebc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
"""
STT Provider Factory for Flare
"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from logger import log_info, log_error, log_warning, log_debug
from stt_google import GoogleCloudSTT
from config_provider import ConfigProvider
from stt_interface import STTInterface
# Registry of available STT provider classes, keyed by the engine name used
# in configuration. Each backend is imported inside its own try/except so
# that a missing backend module only removes that provider from the registry
# instead of making this module fail to import.
# NOTE(review): the unconditional `from stt_google import GoogleCloudSTT` in
# the top-of-file imports re-introduces a hard dependency on stt_google and
# defeats the guard below; it should probably be removed.
stt_providers = {}

try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    # Same condition for every backend, so log it at the same (warning) level.
    log_warning("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")
class NoSTT(STTInterface):
    """Null-object STT provider used when speech-to-text is disabled.

    Every operation is a no-op: streaming may be started and stopped, but
    no recognition results are ever produced.
    """

    async def start_streaming(self, config) -> None:
        """Accept and ignore the streaming configuration."""

    async def stream_audio(self, audio_chunk: bytes):
        """Discard the audio chunk without producing any results.

        The unreachable ``yield`` makes this an async generator, so callers
        can ``async for`` over it exactly like a real provider; the loop
        simply ends immediately.
        """
        if False:
            yield None

    async def stop_streaming(self):
        """Nothing to finalize; always return None."""
        return None

    def supports_realtime(self) -> bool:
        """Report that no real-time recognition is available."""
        return False

    def get_supported_languages(self):
        """Return a fresh empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of the disabled-STT provider."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers from the global configuration."""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the global configuration.

        Reads ``global_config.stt_provider`` via ``ConfigProvider.get()``.
        Its ``name`` selects a class from the ``stt_providers`` registry, and
        its ``api_key`` is used to construct the instance.

        Returns:
            A provider instance. A ``NoSTT`` placeholder is returned instead
            when:
              - STT is disabled (``"no_stt"``);
              - the provider is not registered;
              - the API key is missing or malformed;
              - any exception occurs during creation.
            Errors are logged rather than raised.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name
            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # API key or credentials, interpreted per engine below.
            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                azure_credentials = STTFactory._split_azure_key(api_key)
                if azure_credentials is None:
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = azure_credentials
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered provider take a plain API key.
            return provider_class(api_key=api_key)
        except Exception as e:
            # Top-level boundary: provider setup must not crash the caller,
            # so log the failure and fall back to disabled STT.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def _split_azure_key(api_key):
        """Split an Azure ``subscription_key|region`` credential string.

        Returns a ``(subscription_key, region)`` tuple with surrounding
        whitespace stripped from each part. Returns None when the string
        does not contain exactly one ``|``, or when either part is empty
        after stripping.
        """
        parts = [part.strip() for part in api_key.split('|')]
        if len(parts) != 2 or not all(parts):
            return None
        return parts[0], parts[1]

    @staticmethod
    def get_available_providers():
        """Return the names of all registered STT providers plus "no_stt"."""
        return list(stt_providers.keys()) + ["no_stt"]