ciyidogan commited on
Commit
ff2d19f
·
verified ·
1 Parent(s): 78b5a88

Update stt/stt_factory.py

Browse files
Files changed (1) hide show
  1. stt/stt_factory.py +123 -124
stt/stt_factory.py CHANGED
@@ -1,125 +1,124 @@
1
- """
2
- STT Provider Factory for Flare
3
- """
4
- from typing import Optional
5
- from .stt_interface import STTInterface, STTEngineType
6
- from utils.logger import log_info, log_error, log_warning, log_debug
7
- from .stt_google import GoogleCloudSTT
8
- from config.config_provider import ConfigProvider
9
-
10
- # Import providers conditionally
11
- stt_providers = {}
12
-
13
- try:
14
- from .stt_google import GoogleCloudSTT
15
- stt_providers['google'] = GoogleCloudSTT
16
- except ImportError:
17
- log_info("⚠️ Google Cloud STT not available")
18
-
19
- try:
20
- from .stt_deepgram import DeepgramSTT
21
- stt_providers['deepgram'] = DeepgramSTT
22
- except ImportError:
23
- log_info("⚠️ Deepgram STT not available")
24
-
25
- try:
26
- from .stt_azure import AzureSTT
27
- stt_providers['azure'] = AzureSTT
28
- except ImportError:
29
- log_error("⚠️ Azure STT not available")
30
-
31
- try:
32
- from .stt_flicker import FlickerSTT
33
- stt_providers['flicker'] = FlickerSTT
34
- except ImportError:
35
- log_error("⚠️ Flicker STT not available")
36
-
37
- class NoSTT(STTInterface):
38
- """Dummy STT provider when STT is disabled"""
39
-
40
- async def start_streaming(self, config) -> None:
41
- pass
42
-
43
- async def stream_audio(self, audio_chunk: bytes):
44
- return
45
- yield # Make it a generator
46
-
47
- async def stop_streaming(self):
48
- return None
49
-
50
- def supports_realtime(self) -> bool:
51
- return False
52
-
53
- def get_supported_languages(self):
54
- return []
55
-
56
- def get_provider_name(self) -> str:
57
- return "no_stt"
58
-
59
- class STTFactory:
60
- """Factory for creating STT providers"""
61
-
62
- @staticmethod
63
- def create_provider() -> Optional[STTInterface]:
64
- """Create STT provider based on configuration"""
65
- try:
66
- cfg = ConfigProvider.get()
67
- stt_provider_config = cfg.global_config.stt_provider
68
- stt_engine = stt_provider_config.name
69
-
70
- log_info(f"🎤 Creating STT provider: {stt_engine}")
71
-
72
- if stt_engine == "no_stt":
73
- return NoSTT()
74
-
75
- # Get provider class
76
- provider_class = stt_providers.get(stt_engine)
77
- if not provider_class:
78
- log_warning(f"⚠️ STT provider '{stt_engine}' not available")
79
- return NoSTT()
80
-
81
- # Get API key or credentials
82
- api_key = STTFactory._get_api_key(stt_provider_config)
83
-
84
- if not api_key and stt_provider_config.requires_api_key:
85
- log_warning(f"⚠️ No API key configured for {stt_engine}")
86
- return NoSTT()
87
-
88
- # Create provider instance
89
- if stt_engine == "google":
90
- # For Google, api_key is the path to credentials JSON
91
- return provider_class(credentials_path=api_key)
92
- elif stt_engine == "deepgram":
93
- return provider_class(api_key=api_key)
94
- elif stt_engine == "azure":
95
- # For Azure, parse the key format
96
- parts = api_key.split('|')
97
- if len(parts) != 2:
98
- log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
99
- return NoSTT()
100
- return provider_class(subscription_key=parts[0], region=parts[1])
101
- elif stt_engine == "flicker":
102
- return provider_class(api_key=api_key)
103
- else:
104
- return provider_class(api_key=api_key)
105
-
106
- except Exception as e:
107
- log_error("❌ Failed to create STT provider", e)
108
- return NoSTT()
109
-
110
- @staticmethod
111
- def get_available_providers():
112
- """Get list of available STT providers"""
113
- return list(stt_providers.keys()) + ["no_stt"]
114
-
115
- @staticmethod
116
- def _get_api_key(stt_config) -> Optional[str]:
117
- """Get decrypted API key"""
118
- if not stt_config.api_key:
119
- return None
120
-
121
- if stt_config.api_key.startswith("enc:"):
122
- from utils.encryption_utils import decrypt
123
- return decrypt(stt_config.api_key)
124
-
125
  return stt_config.api_key
 
1
+ """
2
+ STT Provider Factory for Flare
3
+ """
4
+ from typing import Optional
5
+ from .stt_interface import STTInterface, STTEngineType
6
+ from utils.logger import log_info, log_error, log_warning, log_debug
7
+ from config.config_provider import ConfigProvider
8
+
9
+ # Import providers conditionally
10
+ stt_providers = {}
11
+
12
+ try:
13
+ from .stt_google import GoogleSTT
14
+ stt_providers['google'] = GoogleSTT
15
+ except ImportError:
16
+ log_info("⚠️ Google Cloud STT not available")
17
+
18
+ try:
19
+ from .stt_deepgram import DeepgramSTT
20
+ stt_providers['deepgram'] = DeepgramSTT
21
+ except ImportError:
22
+ log_info("⚠️ Deepgram STT not available")
23
+
24
+ try:
25
+ from .stt_azure import AzureSTT
26
+ stt_providers['azure'] = AzureSTT
27
+ except ImportError:
28
+ log_error("⚠️ Azure STT not available")
29
+
30
+ try:
31
+ from .stt_flicker import FlickerSTT
32
+ stt_providers['flicker'] = FlickerSTT
33
+ except ImportError:
34
+ log_error("⚠️ Flicker STT not available")
35
+
36
+ class NoSTT(STTInterface):
37
+ """Dummy STT provider when STT is disabled"""
38
+
39
+ async def start_streaming(self, config) -> None:
40
+ pass
41
+
42
+ async def stream_audio(self, audio_chunk: bytes):
43
+ return
44
+ yield # Make it a generator
45
+
46
+ async def stop_streaming(self):
47
+ return None
48
+
49
+ def supports_realtime(self) -> bool:
50
+ return False
51
+
52
+ def get_supported_languages(self):
53
+ return []
54
+
55
+ def get_provider_name(self) -> str:
56
+ return "no_stt"
57
+
58
+ class STTFactory:
59
+ """Factory for creating STT providers"""
60
+
61
+ @staticmethod
62
+ def create_provider() -> Optional[STTInterface]:
63
+ """Create STT provider based on configuration"""
64
+ try:
65
+ cfg = ConfigProvider.get()
66
+ stt_provider_config = cfg.global_config.stt_provider
67
+ stt_engine = stt_provider_config.name
68
+
69
+ log_info(f"🎤 Creating STT provider: {stt_engine}")
70
+
71
+ if stt_engine == "no_stt":
72
+ return NoSTT()
73
+
74
+ # Get provider class
75
+ provider_class = stt_providers.get(stt_engine)
76
+ if not provider_class:
77
+ log_warning(f"⚠️ STT provider '{stt_engine}' not available")
78
+ return NoSTT()
79
+
80
+ # Get API key or credentials
81
+ api_key = STTFactory._get_api_key(stt_provider_config)
82
+
83
+ if not api_key and stt_provider_config.requires_api_key:
84
+ log_warning(f"⚠️ No API key configured for {stt_engine}")
85
+ return NoSTT()
86
+
87
+ # Create provider instance
88
+ if stt_engine == "google":
89
+ # For Google, api_key is the path to credentials JSON
90
+ return provider_class(credentials_path=api_key)
91
+ elif stt_engine == "deepgram":
92
+ return provider_class(api_key=api_key)
93
+ elif stt_engine == "azure":
94
+ # For Azure, parse the key format
95
+ parts = api_key.split('|')
96
+ if len(parts) != 2:
97
+ log_warning("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
98
+ return NoSTT()
99
+ return provider_class(subscription_key=parts[0], region=parts[1])
100
+ elif stt_engine == "flicker":
101
+ return provider_class(api_key=api_key)
102
+ else:
103
+ return provider_class(api_key=api_key)
104
+
105
+ except Exception as e:
106
+ log_error("❌ Failed to create STT provider", e)
107
+ return NoSTT()
108
+
109
+ @staticmethod
110
+ def get_available_providers():
111
+ """Get list of available STT providers"""
112
+ return list(stt_providers.keys()) + ["no_stt"]
113
+
114
+ @staticmethod
115
+ def _get_api_key(stt_config) -> Optional[str]:
116
+ """Get decrypted API key"""
117
+ if not stt_config.api_key:
118
+ return None
119
+
120
+ if stt_config.api_key.startswith("enc:"):
121
+ from utils.encryption_utils import decrypt
122
+ return decrypt(stt_config.api_key)
123
+
 
124
  return stt_config.api_key