ciyidogan commited on
Commit
913182c
Β·
verified Β·
1 Parent(s): f99d306

Update stt/stt_lifecycle_manager.py

Browse files
Files changed (1) hide show
  1. stt/stt_lifecycle_manager.py +91 -231
stt/stt_lifecycle_manager.py CHANGED
@@ -1,7 +1,7 @@
1
  """
2
- STT Lifecycle Manager for Flare
3
  ===============================
4
- Manages STT instances lifecycle per session
5
  """
6
  import asyncio
7
  from typing import Dict, Optional, Any
@@ -13,25 +13,34 @@ from chat_session.event_bus import EventBus, Event, EventType, publish_error
13
  from chat_session.resource_manager import ResourceManager, ResourceType
14
  from stt.stt_factory import STTFactory
15
  from stt.stt_interface import STTInterface, STTConfig, TranscriptionResult
 
16
  from utils.logger import log_info, log_error, log_debug, log_warning
17
 
18
 
19
  class STTSession:
20
- """STT session wrapper"""
21
 
22
  def __init__(self, session_id: str, stt_instance: STTInterface):
23
  self.session_id = session_id
24
  self.stt_instance = stt_instance
25
- self.is_streaming = False
26
  self.config: Optional[STTConfig] = None
27
  self.created_at = datetime.utcnow()
28
- self.last_activity = datetime.utcnow()
 
 
 
 
 
29
  self.total_chunks = 0
30
  self.total_bytes = 0
31
 
32
- def update_activity(self):
33
- """Update last activity timestamp"""
34
- self.last_activity = datetime.utcnow()
 
 
 
35
 
36
 
37
  class STTLifecycleManager:
@@ -51,29 +60,6 @@ class STTLifecycleManager:
51
  self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
52
  self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)
53
 
54
- def _setup_resource_pool(self):
55
- """Setup STT instance pool"""
56
- self.resource_manager.register_pool(
57
- resource_type=ResourceType.STT_INSTANCE,
58
- factory=self._create_stt_instance,
59
- max_idle=5,
60
- max_age_seconds=300 # 5 minutes
61
- )
62
-
63
- async def _create_stt_instance(self) -> STTInterface:
64
- """Factory for creating STT instances"""
65
- try:
66
- stt_instance = STTFactory.create_provider()
67
- if not stt_instance:
68
- raise ValueError("Failed to create STT instance")
69
-
70
- log_debug("🎀 Created new STT instance")
71
- return stt_instance
72
-
73
- except Exception as e:
74
- log_error(f"❌ Failed to create STT instance", error=str(e))
75
- raise
76
-
77
  async def _handle_stt_start(self, event: Event):
78
  """Handle STT start request"""
79
  session_id = event.session_id
@@ -82,13 +68,8 @@ class STTLifecycleManager:
82
  try:
83
  log_info(f"🎀 Starting STT", session_id=session_id)
84
 
85
- # Check if already exists
86
- if session_id in self.stt_sessions:
87
- stt_session = self.stt_sessions[session_id]
88
- if stt_session.is_streaming:
89
- log_warning(f"⚠️ STT already streaming", session_id=session_id)
90
- return
91
- else:
92
  # Acquire STT instance from pool
93
  resource_id = f"stt_{session_id}"
94
  stt_instance = await self.resource_manager.acquire(
@@ -98,40 +79,28 @@ class STTLifecycleManager:
98
  cleanup_callback=self._cleanup_stt_instance
99
  )
100
 
101
- # Create session wrapper
102
  stt_session = STTSession(session_id, stt_instance)
103
  self.stt_sessions[session_id] = stt_session
 
 
 
104
 
105
- # Get session locale from state orchestrator
106
  locale = config_data.get("locale", "tr")
107
-
108
- # Build STT config - βœ… CONTINUOUS LISTENING Δ°Γ‡Δ°N AYARLAR
109
  stt_config = STTConfig(
110
  language=self._get_language_code(locale),
111
  sample_rate=config_data.get("sample_rate", 16000),
112
- encoding=config_data.get("encoding", "LINEAR16"), # Try "LINEAR16" if WEBM fails
113
  enable_punctuation=config_data.get("enable_punctuation", True),
114
- enable_word_timestamps=False,
115
  model=config_data.get("model", "latest_long"),
116
  use_enhanced=config_data.get("use_enhanced", True),
117
- single_utterance=False, # βœ… Continuous listening iΓ§in FALSE olmalΔ±
118
- interim_results=True, # βœ… Interim results'Δ± AΓ‡
119
  )
120
 
121
- # Log the exact config being used
122
- log_info(f"οΏ½οΏ½οΏ½οΏ½ STT Config: encoding={stt_config.encoding}, "
123
- f"sample_rate={stt_config.sample_rate}, "
124
- f"single_utterance={stt_config.single_utterance}, "
125
- f"interim_results={stt_config.interim_results}")
126
-
127
  stt_session.config = stt_config
 
128
 
129
- # Start streaming
130
- await stt_session.stt_instance.start_streaming(stt_config)
131
- stt_session.is_streaming = True
132
- stt_session.update_activity()
133
-
134
- log_info(f"βœ… STT started in continuous mode with interim results", session_id=session_id, language=stt_config.language)
135
 
136
  # Notify STT is ready
137
  await self.event_bus.publish(Event(
@@ -159,103 +128,45 @@ class STTLifecycleManager:
159
  error_message=f"Failed to start STT: {str(e)}"
160
  )
161
 
162
- async def _handle_stt_stop(self, event: Event):
163
- """Handle STT stop request"""
164
- session_id = event.session_id
165
- reason = event.data.get("reason", "unknown")
166
-
167
- log_info(f"πŸ›‘ Stopping STT", session_id=session_id, reason=reason)
168
-
169
- stt_session = self.stt_sessions.get(session_id)
170
- if not stt_session:
171
- log_warning(f"⚠️ No STT session found", session_id=session_id)
172
- return
173
-
174
- try:
175
- if stt_session.is_streaming:
176
- # Stop streaming
177
- final_result = await stt_session.stt_instance.stop_streaming()
178
- stt_session.is_streaming = False
179
-
180
- # If we got a final result, publish it
181
- if final_result and final_result.text:
182
- await self.event_bus.publish(Event(
183
- type=EventType.STT_RESULT,
184
- session_id=session_id,
185
- data={
186
- "text": final_result.text,
187
- "is_final": True,
188
- "confidence": final_result.confidence
189
- }
190
- ))
191
-
192
- # βœ… Send STT_STOPPED event (websocket_manager will handle it)
193
- await self.event_bus.publish(Event(
194
- type=EventType.STT_STOPPED,
195
- session_id=session_id,
196
- data={"reason": reason}
197
- ))
198
-
199
- # Don't remove session immediately - might restart
200
- stt_session.update_activity()
201
-
202
- log_info(f"βœ… STT stopped", session_id=session_id)
203
-
204
- except Exception as e:
205
- log_error(
206
- f"❌ Error stopping STT",
207
- session_id=session_id,
208
- error=str(e)
209
- )
210
-
211
  async def _handle_audio_chunk(self, event: Event):
212
- """Process audio chunk through STT"""
213
  session_id = event.session_id
214
 
215
  stt_session = self.stt_sessions.get(session_id)
216
- if not stt_session or not stt_session.is_streaming:
217
- # STT not ready, ignore chunk
218
  return
219
 
220
  try:
221
  # Decode audio data
222
  audio_data = base64.b64decode(event.data.get("audio_data", ""))
223
 
224
- # Update stats
 
225
  stt_session.total_chunks += 1
226
  stt_session.total_bytes += len(audio_data)
227
- stt_session.update_activity()
228
 
229
- # Stream to STT
230
- async for result in stt_session.stt_instance.stream_audio(audio_data):
231
- # Publish transcription results
 
 
 
 
 
232
  await self.event_bus.publish(Event(
233
- type=EventType.STT_RESULT,
234
  session_id=session_id,
235
- data={
236
- "text": result.text,
237
- "is_final": result.is_final,
238
- "confidence": result.confidence,
239
- "timestamp": result.timestamp
240
- }
241
  ))
242
 
243
- # Log final results
244
- if result.is_final:
245
- log_info(
246
- f"πŸ“ STT final result",
247
- session_id=session_id,
248
- text=result.text[:50] + "..." if len(result.text) > 50 else result.text,
249
- confidence=result.confidence
250
- )
251
-
252
  # Log progress periodically
253
  if stt_session.total_chunks % 100 == 0:
254
  log_debug(
255
  f"πŸ“Š STT progress",
256
  session_id=session_id,
257
  chunks=stt_session.total_chunks,
258
- bytes=stt_session.total_bytes
 
259
  )
260
 
261
  except Exception as e:
@@ -265,111 +176,60 @@ class STTLifecycleManager:
265
  error=str(e)
266
  )
267
 
268
- # Check if it's a recoverable error
269
- if "stream duration" in str(e) or "timeout" in str(e).lower():
270
- # STT timeout, restart needed
271
- await publish_error(
272
- session_id=session_id,
273
- error_type="stt_timeout",
274
- error_message="STT stream timeout, restart needed"
275
- )
276
- else:
277
- # Other STT error
278
- await publish_error(
279
- session_id=session_id,
280
- error_type="stt_error",
281
- error_message=str(e)
282
- )
283
-
284
- async def _handle_session_ended(self, event: Event):
285
- """Clean up STT resources when session ends"""
286
  session_id = event.session_id
287
- await self._cleanup_session(session_id)
288
-
289
- async def _cleanup_session(self, session_id: str):
290
- """Clean up STT session"""
291
- stt_session = self.stt_sessions.pop(session_id, None)
292
  if not stt_session:
 
293
  return
294
-
295
  try:
296
- # Stop streaming if active
297
- if stt_session.is_streaming:
298
- await stt_session.stt_instance.stop_streaming()
299
-
300
- # Release resource
301
- resource_id = f"stt_{session_id}"
302
- await self.resource_manager.release(resource_id, delay_seconds=60)
303
-
304
- log_info(
305
- f"🧹 STT session cleaned up",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  session_id=session_id,
307
- total_chunks=stt_session.total_chunks,
308
- total_bytes=stt_session.total_bytes
309
- )
310
-
 
311
  except Exception as e:
312
  log_error(
313
- f"❌ Error cleaning up STT session",
314
  session_id=session_id,
315
  error=str(e)
316
- )
317
-
318
- async def _cleanup_stt_instance(self, stt_instance: STTInterface):
319
- """Cleanup callback for STT instance"""
320
- try:
321
- # Ensure streaming is stopped
322
- if hasattr(stt_instance, 'is_streaming') and stt_instance.is_streaming:
323
- await stt_instance.stop_streaming()
324
-
325
- log_debug("🧹 STT instance cleaned up")
326
-
327
- except Exception as e:
328
- log_error(f"❌ Error cleaning up STT instance", error=str(e))
329
-
330
- def _get_language_code(self, locale: str) -> str:
331
- """Convert locale to STT language code"""
332
- # Map common locales to STT language codes
333
- locale_map = {
334
- "tr": "tr-TR",
335
- "en": "en-US",
336
- "de": "de-DE",
337
- "fr": "fr-FR",
338
- "es": "es-ES",
339
- "it": "it-IT",
340
- "pt": "pt-BR",
341
- "ru": "ru-RU",
342
- "ja": "ja-JP",
343
- "ko": "ko-KR",
344
- "zh": "zh-CN",
345
- "ar": "ar-SA"
346
- }
347
-
348
- # Check direct match
349
- if locale in locale_map:
350
- return locale_map[locale]
351
-
352
- # Check if it's already a full code
353
- if "-" in locale and len(locale) == 5:
354
- return locale
355
-
356
- # Default to locale-LOCALE format
357
- return f"{locale}-{locale.upper()}"
358
-
359
- def get_stats(self) -> Dict[str, Any]:
360
- """Get STT manager statistics"""
361
- session_stats = {}
362
- for session_id, stt_session in self.stt_sessions.items():
363
- session_stats[session_id] = {
364
- "is_streaming": stt_session.is_streaming,
365
- "total_chunks": stt_session.total_chunks,
366
- "total_bytes": stt_session.total_bytes,
367
- "uptime_seconds": (datetime.utcnow() - stt_session.created_at).total_seconds(),
368
- "last_activity": stt_session.last_activity.isoformat()
369
- }
370
-
371
- return {
372
- "active_sessions": len(self.stt_sessions),
373
- "streaming_sessions": sum(1 for s in self.stt_sessions.values() if s.is_streaming),
374
- "sessions": session_stats
375
- }
 
1
  """
2
+ STT Lifecycle Manager for Flare - Batch Mode
3
  ===============================
4
+ Manages STT instances and audio collection
5
  """
6
  import asyncio
7
  from typing import Dict, Optional, Any
 
13
  from chat_session.resource_manager import ResourceManager, ResourceType
14
  from stt.stt_factory import STTFactory
15
  from stt.stt_interface import STTInterface, STTConfig, TranscriptionResult
16
+ from stt.voice_activity_detector import VoiceActivityDetector
17
  from utils.logger import log_info, log_error, log_debug, log_warning
18
 
19
 
20
  class STTSession:
21
+ """STT session with audio collection"""
22
 
23
  def __init__(self, session_id: str, stt_instance: STTInterface):
24
  self.session_id = session_id
25
  self.stt_instance = stt_instance
26
+ self.is_active = False
27
  self.config: Optional[STTConfig] = None
28
  self.created_at = datetime.utcnow()
29
+
30
+ # Audio collection
31
+ self.audio_buffer = []
32
+ self.vad = VoiceActivityDetector()
33
+
34
+ # Stats
35
  self.total_chunks = 0
36
  self.total_bytes = 0
37
 
38
+ def reset(self):
39
+ """Reset session for new utterance"""
40
+ self.audio_buffer = []
41
+ self.vad.reset()
42
+ self.total_chunks = 0
43
+ self.total_bytes = 0
44
 
45
 
46
  class STTLifecycleManager:
 
60
  self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
61
  self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  async def _handle_stt_start(self, event: Event):
64
  """Handle STT start request"""
65
  session_id = event.session_id
 
68
  try:
69
  log_info(f"🎀 Starting STT", session_id=session_id)
70
 
71
+ # Get or create session
72
+ if session_id not in self.stt_sessions:
 
 
 
 
 
73
  # Acquire STT instance from pool
74
  resource_id = f"stt_{session_id}"
75
  stt_instance = await self.resource_manager.acquire(
 
79
  cleanup_callback=self._cleanup_stt_instance
80
  )
81
 
82
+ # Create session
83
  stt_session = STTSession(session_id, stt_instance)
84
  self.stt_sessions[session_id] = stt_session
85
+ else:
86
+ stt_session = self.stt_sessions[session_id]
87
+ stt_session.reset()
88
 
89
+ # Build STT config
90
  locale = config_data.get("locale", "tr")
 
 
91
  stt_config = STTConfig(
92
  language=self._get_language_code(locale),
93
  sample_rate=config_data.get("sample_rate", 16000),
94
+ encoding=config_data.get("encoding", "LINEAR16"),
95
  enable_punctuation=config_data.get("enable_punctuation", True),
 
96
  model=config_data.get("model", "latest_long"),
97
  use_enhanced=config_data.get("use_enhanced", True),
 
 
98
  )
99
 
 
 
 
 
 
 
100
  stt_session.config = stt_config
101
+ stt_session.is_active = True
102
 
103
+ log_info(f"βœ… STT started in batch mode", session_id=session_id, language=stt_config.language)
 
 
 
 
 
104
 
105
  # Notify STT is ready
106
  await self.event_bus.publish(Event(
 
128
  error_message=f"Failed to start STT: {str(e)}"
129
  )
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  async def _handle_audio_chunk(self, event: Event):
132
+ """Process audio chunk through VAD and collect"""
133
  session_id = event.session_id
134
 
135
  stt_session = self.stt_sessions.get(session_id)
136
+ if not stt_session or not stt_session.is_active:
 
137
  return
138
 
139
  try:
140
  # Decode audio data
141
  audio_data = base64.b64decode(event.data.get("audio_data", ""))
142
 
143
+ # Add to buffer
144
+ stt_session.audio_buffer.append(audio_data)
145
  stt_session.total_chunks += 1
146
  stt_session.total_bytes += len(audio_data)
 
147
 
148
+ # Process through VAD
149
+ is_speech, silence_duration_ms = stt_session.vad.process_chunk(audio_data)
150
+
151
+ # Check if utterance ended (silence threshold reached)
152
+ if not is_speech and silence_duration_ms >= 2000: # 2 seconds of silence
153
+ log_info(f"πŸ’¬ Utterance ended after {silence_duration_ms}ms silence", session_id=session_id)
154
+
155
+ # Stop STT to trigger transcription
156
  await self.event_bus.publish(Event(
157
+ type=EventType.STT_STOPPED,
158
  session_id=session_id,
159
+ data={"reason": "silence_detected"}
 
 
 
 
 
160
  ))
161
 
 
 
 
 
 
 
 
 
 
162
  # Log progress periodically
163
  if stt_session.total_chunks % 100 == 0:
164
  log_debug(
165
  f"πŸ“Š STT progress",
166
  session_id=session_id,
167
  chunks=stt_session.total_chunks,
168
+ bytes=stt_session.total_bytes,
169
+ vad_stats=stt_session.vad.get_stats()
170
  )
171
 
172
  except Exception as e:
 
176
  error=str(e)
177
  )
178
 
179
+ async def _handle_stt_stop(self, event: Event):
180
+ """Handle STT stop request and perform transcription"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
  session_id = event.session_id
182
+ reason = event.data.get("reason", "unknown")
183
+
184
+ log_info(f"πŸ›‘ Stopping STT", session_id=session_id, reason=reason)
185
+
186
+ stt_session = self.stt_sessions.get(session_id)
187
  if not stt_session:
188
+ log_warning(f"⚠️ No STT session found", session_id=session_id)
189
  return
190
+
191
  try:
192
+ if stt_session.is_active and stt_session.audio_buffer:
193
+ # Combine audio chunks
194
+ combined_audio = b''.join(stt_session.audio_buffer)
195
+
196
+ # Transcribe using batch mode
197
+ log_info(f"πŸ“ Transcribing {len(combined_audio)} bytes of audio", session_id=session_id)
198
+ result = await stt_session.stt_instance.transcribe(
199
+ audio_data=combined_audio,
200
+ config=stt_session.config
201
+ )
202
+
203
+ # Publish result if we got transcription
204
+ if result and result.text:
205
+ await self.event_bus.publish(Event(
206
+ type=EventType.STT_RESULT,
207
+ session_id=session_id,
208
+ data={
209
+ "text": result.text,
210
+ "is_final": True,
211
+ "confidence": result.confidence
212
+ }
213
+ ))
214
+ else:
215
+ log_warning(f"⚠️ No transcription result", session_id=session_id)
216
+
217
+ # Mark as inactive and reset
218
+ stt_session.is_active = False
219
+ stt_session.reset()
220
+
221
+ # Send STT_STOPPED event
222
+ await self.event_bus.publish(Event(
223
+ type=EventType.STT_STOPPED,
224
  session_id=session_id,
225
+ data={"reason": reason}
226
+ ))
227
+
228
+ log_info(f"βœ… STT stopped", session_id=session_id)
229
+
230
  except Exception as e:
231
  log_error(
232
+ f"❌ Error stopping STT",
233
  session_id=session_id,
234
  error=str(e)
235
+ )