Spaces:
Building
Building
File size: 3,330 Bytes
a847f43 b236687 a847f43 b236687 98796a5 a847f43 b236687 321ebc6 a847f43 321ebc6 a847f43 321ebc6 a847f43 b236687 321ebc6 b236687 321ebc6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
"""
STT Provider Factory for Flare
"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from utils import log
from stt_google import GoogleCloudSTT
from config_provider import ConfigProvider
from stt_interface import STTInterface
# Optional STT back-ends. Each engine is registered in ``stt_providers`` only
# when its module imports cleanly, so a missing SDK disables that one engine
# instead of breaking this module.
# NOTE(review): the unconditional ``from stt_google import GoogleCloudSTT`` in
# the import block above executes first, so a missing Google SDK raises before
# this guard is ever reached — confirm whether that import should be removed.
stt_providers = {}

try:
    from stt_google import GoogleCloudSTT
except ImportError:
    log("⚠️ Google Cloud STT not available")
else:
    stt_providers['google'] = GoogleCloudSTT

try:
    from stt_azure import AzureSTT
except ImportError:
    log("⚠️ Azure STT not available")
else:
    stt_providers['azure'] = AzureSTT

try:
    from stt_flicker import FlickerSTT
except ImportError:
    log("⚠️ Flicker STT not available")
else:
    stt_providers['flicker'] = FlickerSTT
class NoSTT(STTInterface):
    """Placeholder STT provider used when speech-to-text is disabled.

    Every operation is a no-op: streaming accepts input but never produces a
    result, no languages are advertised, and realtime is reported as
    unsupported.
    """

    async def start_streaming(self, config) -> None:
        """Accept a streaming configuration and ignore it."""

    async def stream_audio(self, audio_chunk: bytes):
        """Discard ``audio_chunk`` and finish without yielding anything."""
        # The unreachable ``yield`` makes this an async generator, so callers
        # can still ``async for`` over it; the loop simply ends immediately.
        return
        yield

    async def stop_streaming(self):
        """Stop streaming; there is never a final result, so return None."""
        return None

    def supports_realtime(self) -> bool:
        """Report that realtime streaming is not supported."""
        return False

    def get_supported_languages(self):
        """Return an empty list, since no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        return "no_stt"
class STTFactory:
    """Factory for creating STT providers from the global configuration."""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected by ``global_config.stt_engine``.

        The engine name and its credential string are read from
        ``ConfigProvider.get().global_config``. The credential string is
        interpreted per engine:

        * ``google`` - path to a credentials JSON file (``credentials_path``)
        * ``azure``  - ``"subscription_key|region"``; both parts must be
          non-empty, and surrounding whitespace is ignored
        * any other registered engine (e.g. ``flicker``) - passed as ``api_key``

        Returns:
            A provider instance. Falls back to ``NoSTT`` in these cases:
            STT is disabled (``"no_stt"``), the engine is not registered,
            no credentials are configured, the Azure key is malformed, or
            any step raises. This method never returns ``None``.
        """
        try:
            cfg = ConfigProvider.get()
            stt_engine = cfg.global_config.stt_engine
            log(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only engines whose module imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # API key or credentials; the meaning depends on the engine.
            api_key = cfg.global_config.get_stt_api_key()
            if not api_key:
                log(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to the credentials JSON.
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # Reject anything that is not exactly two non-empty parts, so
                # inputs like "key|" or "|region" never reach the Azure SDK.
                parts = [part.strip() for part in api_key.split('|')]
                if len(parts) != 2 or not all(parts):
                    log("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered engine take a plain API key.
            return provider_class(api_key=api_key)
        except Exception as e:
            # Top-level boundary: provider setup must not break the caller,
            # so degrade to the no-op provider and log the reason.
            log(f"❌ Failed to create STT provider: {e}")
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the names of the importable STT providers, plus ``"no_stt"``."""
        return list(stt_providers.keys()) + ["no_stt"]