File size: 3,330 Bytes
a847f43
b236687
a847f43
b236687
98796a5
 
a847f43
b236687
321ebc6
a847f43
321ebc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a847f43
321ebc6
a847f43
 
b236687
 
321ebc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b236687
 
321ebc6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
STT Provider Factory for Flare
"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from utils import log
from stt_google import GoogleCloudSTT
from config_provider import ConfigProvider
from stt_interface import STTInterface

# Import providers conditionally.
# Registry mapping engine name (as stored in global_config.stt_engine) to the
# provider class. An engine is registered only if its module imports cleanly;
# STTFactory.create_provider looks engines up here and falls back to NoSTT
# for any name that is missing.
stt_providers = {}

# NOTE(review): GoogleCloudSTT is also imported unconditionally in the
# top-of-file imports, so a missing stt_google module raises there before
# this guard runs. The "not available" branch below is effectively
# unreachable for that case. Confirm whether the top-level import is intended.
try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log("⚠️ Flicker STT not available")

class NoSTT(STTInterface):
    """No-op STT provider used when speech-to-text is disabled.

    Every call completes immediately and produces nothing. Starting and
    stopping a stream are no-ops, streaming yields no results, realtime
    support is reported as absent, and no languages are advertised.
    """

    async def start_streaming(self, config) -> None:
        """Accept *config* and ignore it; there is nothing to start."""

    async def stream_audio(self, audio_chunk: bytes):
        """Discard *audio_chunk* and produce an empty async stream."""
        # Returning before the bare ``yield`` leaves the generator empty.
        # The ``yield`` still has to be present: it makes this an async
        # generator, so callers can ``async for`` over it.
        return
        yield

    async def stop_streaming(self):
        """End the (non-existent) stream; the result is always ``None``."""

    def supports_realtime(self) -> bool:
        """Report that this provider offers no realtime recognition."""
        realtime = False
        return realtime

    def get_supported_languages(self):
        """Return an empty list: no languages are recognised."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        name = "no_stt"
        return name
        
class STTFactory:
    """Factory for creating STT providers from the global configuration.

    Failures do not raise: every failure path in ``create_provider`` logs
    and returns a ``NoSTT`` instance instead.
    """

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected by ``global_config.stt_engine``.

        The engine name and credential value are read from
        ``ConfigProvider.get().global_config``. The matching class is looked
        up in the module-level ``stt_providers`` registry.

        Credential value per engine (from ``get_stt_api_key()``):
          * ``"google"``: passed as ``credentials_path``.
          * ``"azure"``: ``"subscription_key|region"``. Exactly two
            ``|``-separated parts are required, each non-empty after
            stripping surrounding whitespace.
          * any other registered engine: passed as ``api_key``.

        Returns:
            A provider instance. A ``NoSTT`` instance is returned when any of
            these hold:
              * the engine is ``"no_stt"``;
              * the engine is not registered;
              * no credential is configured;
              * the Azure credential is malformed;
              * any exception is raised while reading config or
                constructing the provider.
            This method never returns ``None`` itself, despite the
            ``Optional`` annotation.
        """
        try:
            cfg = ConfigProvider.get()
            stt_engine = cfg.global_config.stt_engine

            log(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only engines whose module imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # The meaning of this value depends on the engine (see docstring).
            api_key = cfg.global_config.get_stt_api_key()
            if not api_key:
                log(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                parts = [part.strip() for part in api_key.split('|')]
                # Reject a wrong part count and empty halves such as
                # "key|" or "|region", which the old check let through.
                if len(parts) != 2 or not all(parts):
                    log("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # "flicker" and any other registered engine take a plain api_key.
            return provider_class(api_key=api_key)

        except Exception as e:
            # Deliberate best-effort: STT setup problems degrade to NoSTT
            # instead of propagating to the caller.
            log(f"❌ Failed to create STT provider: {e}")
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the names of all selectable STT engines.

        The list contains every registered provider (those whose module
        imported successfully), in registration order, followed by the
        always-available ``"no_stt"`` engine.
        """
        return [*stt_providers, "no_stt"]