File size: 3,958 Bytes
ff2d19f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82b3923
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
STT Provider Factory for Flare
"""
from typing import Optional
from .stt_interface import STTInterface, STTEngineType
from utils.logger import log_info, log_error, log_warning, log_debug
from config.config_provider import ConfigProvider

# Import providers conditionally.
# Registry mapping a provider's configuration name to its implementation class.
# A backend is registered only if its module imports cleanly; an ImportError
# (e.g. a missing optional dependency) just leaves it out of the registry.
# The condition is not fatal, so every backend reports it at the same level.
stt_providers = {}

try:
    from .stt_google import GoogleSTT
    stt_providers['google'] = GoogleSTT
except ImportError:
    log_warning("⚠️ Google Cloud STT not available")

try:
    from .stt_deepgram import DeepgramSTT
    stt_providers['deepgram'] = DeepgramSTT
except ImportError:
    log_warning("⚠️ Deepgram STT not available")

try:
    from .stt_azure import AzureSTT
    stt_providers['azure'] = AzureSTT
except ImportError:
    log_warning("⚠️ Azure STT not available")

try:
    from .stt_flicker import FlickerSTT
    stt_providers['flicker'] = FlickerSTT
except ImportError:
    log_warning("⚠️ Flicker STT not available")

class NoSTT(STTInterface):
    """Placeholder STT provider used when speech-to-text is turned off.

    Every operation is a no-op: streaming produces no results, no languages
    are reported, and real-time support is declared unavailable.
    """

    async def start_streaming(self, config) -> None:
        """Accept ``config`` and do nothing."""

    async def stream_audio(self, audio_chunk: bytes):
        """Ignore ``audio_chunk`` and yield nothing.

        The never-executed ``yield`` below makes this an async generator, so
        callers can ``async for`` over it and simply receive zero items.
        """
        if False:
            yield None

    async def stop_streaming(self):
        """Do nothing; the awaited result is ``None``."""
        return None

    def supports_realtime(self) -> bool:
        """Report that no real-time recognition is offered."""
        realtime = False
        return realtime

    def get_supported_languages(self):
        """Return a fresh, empty list: no languages are supported."""
        return list()

    def get_provider_name(self) -> str:
        """Return the identifier used for the disabled provider."""
        name = "no_stt"
        return name

class STTFactory:
    """Factory for creating STT providers from the global configuration."""

    @staticmethod
    def create_provider() -> Optional[STTInterface]:
        """Create the STT provider selected by ``global_config.stt_provider``.

        Falls back to ``NoSTT`` and logs the reason in each of these cases:
        - STT is disabled (``"no_stt"``).
        - The configured provider is not registered (its module failed to import).
        - A required API key is missing.
        - Azure credentials are malformed.
        - Anything in the construction path raises.

        Returns:
            A provider instance. Every path returns an object; ``None`` is
            never returned despite the ``Optional`` annotation.
        """
        try:
            cfg = ConfigProvider.get()
            stt_provider_config = cfg.global_config.stt_provider
            stt_engine = stt_provider_config.name

            log_info(f"🎤 Creating STT provider: {stt_engine}")

            if stt_engine == "no_stt":
                return NoSTT()

            # Only providers whose module imported successfully are registered.
            provider_class = stt_providers.get(stt_engine)
            if not provider_class:
                log_warning(f"⚠️ STT provider '{stt_engine}' not available")
                return NoSTT()

            # Get API key or credentials (decrypted if stored with "enc:").
            api_key = STTFactory._get_api_key(stt_provider_config)

            if not api_key and stt_provider_config.requires_api_key:
                log_warning(f"⚠️ No API key configured for {stt_engine}")
                return NoSTT()

            if stt_engine == "google":
                # For Google, api_key is the path to credentials JSON
                return provider_class(credentials_path=api_key)

            if stt_engine == "azure":
                # Azure credentials are packed as "subscription_key|region".
                # api_key may be None here when requires_api_key is false, so
                # guard before splitting. Both parts must be non-empty.
                parts = [part.strip() for part in (api_key or "").split('|')]
                if len(parts) != 2 or not all(parts):
                    log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
                    return NoSTT()
                subscription_key, region = parts
                return provider_class(subscription_key=subscription_key, region=region)

            # deepgram, flicker and any other registered provider take a plain API key.
            return provider_class(api_key=api_key)

        except Exception as e:
            # Top-level boundary: never let provider setup crash the caller.
            log_error("❌ Failed to create STT provider", e)
            return NoSTT()

    @staticmethod
    def get_available_providers():
        """Return the names of registered STT providers plus ``"no_stt"``."""
        return list(stt_providers.keys()) + ["no_stt"]

    @staticmethod
    def _get_api_key(stt_config) -> Optional[str]:
        """Return the configured API key, decrypting it if prefixed with ``enc:``.

        Args:
            stt_config: Provider config object exposing an ``api_key`` attribute.

        Returns:
            The plaintext key, or ``None`` when no key is configured (empty or falsy).
        """
        if not stt_config.api_key:
            return None

        if stt_config.api_key.startswith("enc:"):
            # Imported lazily so decryption support is only loaded when needed.
            from utils.encryption_utils import decrypt
            return decrypt(stt_config.api_key)

        return stt_config.api_key