ciyidogan commited on
Commit
321ebc6
·
verified ·
1 Parent(s): 7bdc4b2

Update stt_factory.py

Browse files
Files changed (1) hide show
  1. stt_factory.py +89 -48
stt_factory.py CHANGED
@@ -5,59 +5,100 @@ from typing import Optional
5
  from stt_interface import STTInterface, STTEngineType, log
6
  from stt_google import GoogleCloudSTT
7
  from config_provider import ConfigProvider
 
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  class STTFactory:
10
- """Factory for creating STT provider instances"""
11
 
12
  @staticmethod
13
  def create_provider() -> Optional[STTInterface]:
14
  """Create STT provider based on configuration"""
15
- cfg = ConfigProvider.get()
16
- stt_config = cfg.global_config.stt_provider
17
-
18
- if not stt_config or stt_config.name == "no_stt":
19
- log("🔇 No STT provider configured")
20
- return None
21
-
22
- provider_name = stt_config.name
23
- log(f"🏭 Creating STT provider: {provider_name}")
24
-
25
- # Get provider definition
26
- provider_def = cfg.global_config.get_provider_config("stt", provider_name)
27
- if not provider_def:
28
- log(f"⚠️ Unknown STT provider: {provider_name}")
29
- return None
30
-
31
- # Get credentials/API key
32
- credentials = STTFactory._get_credentials(stt_config)
33
- if not credentials and provider_def.requires_api_key:
34
- log(f"⚠️ No credentials for STT provider: {provider_name}")
35
- return None
36
-
37
- # Create provider based on name
38
- if provider_name == "google":
39
- return GoogleCloudSTT(credentials)
40
- elif provider_name == "azure":
41
- log("⚠️ Azure STT not implemented yet")
42
- return None
43
- elif provider_name == "amazon":
44
- log("⚠️ Amazon STT not implemented yet")
45
- return None
46
- elif provider_name == "flicker":
47
- log("⚠️ Flicker STT not implemented yet")
48
- return None
49
- else:
50
- log(f"⚠️ Unsupported STT provider: {provider_name}")
51
- return None
 
 
 
 
52
 
53
  @staticmethod
54
- def _get_credentials(stt_config) -> Optional[str]:
55
- """Get decrypted credentials/API key"""
56
- if not stt_config.api_key:
57
- return None
58
-
59
- if stt_config.api_key.startswith("enc:"):
60
- from encryption_utils import decrypt
61
- return decrypt(stt_config.api_key)
62
-
63
- return stt_config.api_key
 
5
  from stt_interface import STTInterface, STTEngineType, log
6
  from stt_google import GoogleCloudSTT
7
  from config_provider import ConfigProvider
8
+ from stt_interface import STTInterface
9
 
10
+ # Import providers conditionally
11
+ stt_providers = {}
12
+
13
+ try:
14
+ from stt_google import GoogleCloudSTT
15
+ stt_providers['google'] = GoogleCloudSTT
16
+ except ImportError:
17
+ log("⚠️ Google Cloud STT not available")
18
+
19
+ try:
20
+ from stt_azure import AzureSTT
21
+ stt_providers['azure'] = AzureSTT
22
+ except ImportError:
23
+ log("⚠️ Azure STT not available")
24
+
25
+ try:
26
+ from stt_flicker import FlickerSTT
27
+ stt_providers['flicker'] = FlickerSTT
28
+ except ImportError:
29
+ log("⚠️ Flicker STT not available")
30
+
31
+ class NoSTT(STTInterface):
32
+ """Dummy STT provider when STT is disabled"""
33
+
34
+ async def start_streaming(self, config) -> None:
35
+ pass
36
+
37
+ async def stream_audio(self, audio_chunk: bytes):
38
+ return
39
+ yield # Make it a generator
40
+
41
+ async def stop_streaming(self):
42
+ return None
43
+
44
+ def supports_realtime(self) -> bool:
45
+ return False
46
+
47
+ def get_supported_languages(self):
48
+ return []
49
+
50
+ def get_provider_name(self) -> str:
51
+ return "no_stt"
52
+
53
  class STTFactory:
54
+ """Factory for creating STT providers"""
55
 
56
  @staticmethod
57
  def create_provider() -> Optional[STTInterface]:
58
  """Create STT provider based on configuration"""
59
+ try:
60
+ cfg = ConfigProvider.get()
61
+ stt_engine = cfg.global_config.stt_engine
62
+
63
+ log(f"🎤 Creating STT provider: {stt_engine}")
64
+
65
+ if stt_engine == "no_stt":
66
+ return NoSTT()
67
+
68
+ # Get provider class
69
+ provider_class = stt_providers.get(stt_engine)
70
+ if not provider_class:
71
+ log(f"⚠️ STT provider '{stt_engine}' not available")
72
+ return NoSTT()
73
+
74
+ # Get API key or credentials
75
+ api_key = cfg.global_config.get_stt_api_key()
76
+
77
+ if not api_key:
78
+ log(f"⚠️ No API key configured for {stt_engine}")
79
+ return NoSTT()
80
+
81
+ # Create provider instance
82
+ if stt_engine == "google":
83
+ # For Google, api_key is the path to credentials JSON
84
+ return provider_class(credentials_path=api_key)
85
+ elif stt_engine == "azure":
86
+ # For Azure, parse the key format
87
+ parts = api_key.split('|')
88
+ if len(parts) != 2:
89
+ log("⚠️ Invalid Azure STT key format. Expected: subscription_key|region")
90
+ return NoSTT()
91
+ return provider_class(subscription_key=parts[0], region=parts[1])
92
+ elif stt_engine == "flicker":
93
+ return provider_class(api_key=api_key)
94
+ else:
95
+ return provider_class(api_key=api_key)
96
+
97
+ except Exception as e:
98
+ log(f"❌ Failed to create STT provider: {e}")
99
+ return NoSTT()
100
 
101
  @staticmethod
102
+ def get_available_providers():
103
+ """Get list of available STT providers"""
104
+ return list(stt_providers.keys()) + ["no_stt"]