File size: 3,473 Bytes
a847f43
b236687
a847f43
b236687
98796a5
8eb3adf
a847f43
b236687
321ebc6
a847f43
321ebc6
 
 
 
 
 
 
8eb3adf
321ebc6
 
 
 
 
8eb3adf
321ebc6
 
 
 
 
8eb3adf
321ebc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a847f43
321ebc6
a847f43
 
b236687
 
321ebc6
 
292760f
 
321ebc6
8eb3adf
321ebc6
 
 
 
 
 
 
292760f
321ebc6
 
 
292760f
321ebc6
 
292760f
321ebc6
 
 
 
 
 
 
 
 
 
292760f
321ebc6
 
 
 
 
 
 
 
8eb3adf
321ebc6
b236687
 
321ebc6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
STT Provider Factory for Flare
"""
from typing import Optional
from stt_interface import STTInterface, STTEngineType
from logger import log_info, log_error, log_warning, log_debug
from stt_google import GoogleCloudSTT
from config_provider import ConfigProvider
from stt_interface import STTInterface

# Import providers conditionally.
# Registry mapping the provider name (as looked up by STTFactory) to its
# implementation class. A provider whose module cannot be imported is simply
# left out, so the factory reports it as unavailable instead of crashing.
# NOTE(review): stt_google is also imported unconditionally at the top of
# this file, so the guard around the Google import below cannot take effect
# until that top-level import is removed.
stt_providers = {}

# A missing optional provider is a degraded-but-working state, so every guard
# logs at warning level, matching the "not available" warning in STTFactory.
try:
    from stt_google import GoogleCloudSTT
    stt_providers['google'] = GoogleCloudSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")

class NoSTT(STTInterface):
    """Stand-in STT provider for when STT is disabled.

    Every method is a no-op or returns an empty/negative answer.
    """

    async def start_streaming(self, config) -> None:
        """Ignore ``config``; there is no stream to open."""
        return None

    async def stream_audio(self, audio_chunk: bytes):
        """Async generator that ends immediately without yielding anything."""
        # The never-executed ``yield`` is what makes this an async generator.
        if False:
            yield

    async def stop_streaming(self):
        """Nothing to close; always returns ``None``."""
        return None

    def supports_realtime(self) -> bool:
        """Report that realtime recognition is not supported."""
        return False

    def get_supported_languages(self):
        """Return an empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier of this placeholder provider."""
        return "no_stt"
        
class STTFactory:
    """Factory for creating STT providers.

    Provider classes are looked up in the module-level ``stt_providers``
    registry. Every failure path returns a ``NoSTT`` instance instead of
    raising.
    """

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected in the configuration.

        Reads ``global_config.stt_provider`` from ``ConfigProvider.get()`` and
        uses its ``name`` to pick the provider class and its ``api_key`` to
        build the instance.

        Returns:
            The configured provider instance, or a ``NoSTT`` instance when
            the name is "no_stt", the provider is not registered, no API key
            is set, the Azure key is malformed, or any exception is raised
            along the way.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name

            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only providers whose module imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if provider_class is None:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Get API key or credentials
            api_key = stt_provider_config.api_key
            if not api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to credentials JSON
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # For Azure the key is "subscription_key|region". Both parts
                # must be non-empty, otherwise inputs such as "key|" would
                # build a provider with a blank credential.
                parts = [part.strip() for part in api_key.split('|')]
                if len(parts) != 2 or not all(parts):
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # Flicker and any other registered provider take the key directly.
            return provider_class(api_key=api_key)

        except Exception as e:
            # Top-level boundary: report the failure and degrade to NoSTT.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the registered provider names followed by "no_stt"."""
        return [*stt_providers, "no_stt"]