ciyidogan commited on
Commit
890fc3a
·
verified ·
1 Parent(s): f91801d

Upload websocket-handler.py

Browse files
Files changed (1) hide show
  1. websocket-handler.py +579 -0
websocket-handler.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket Handler for Real-time STT/TTS with Barge-in Support
3
+ """
4
+ from fastapi import WebSocket, WebSocketDisconnect
5
+ from typing import Dict, Any, Optional
6
+ import json
7
+ import asyncio
8
+ import base64
9
+ from datetime import datetime
10
+ from collections import deque
11
+ from enum import Enum
12
+ import numpy as np
13
+ import traceback
14
+
15
+ from session import Session, session_store
16
+ from config_provider import ConfigProvider
17
+ from chat_handler import handle_new_message, handle_parameter_followup
18
+ from stt_factory import STTFactory
19
+ from tts_factory import TTSFactory
20
+ from logger import log_info, log_error, log_debug, log_warning
21
+
22
+ # ========================= CONSTANTS =========================
23
+ # Default values - will be overridden by config
24
+ DEFAULT_SILENCE_THRESHOLD_MS = 2000
25
+ DEFAULT_AUDIO_CHUNK_SIZE = 4096
26
+ DEFAULT_ENERGY_THRESHOLD = 0.01
27
+ DEFAULT_AUDIO_BUFFER_MAX_SIZE = 1000
28
+
29
+ # ========================= ENUMS =========================
30
+ class ConversationState(Enum):
31
+ IDLE = "idle"
32
+ LISTENING = "listening"
33
+ PROCESSING_STT = "processing_stt"
34
+ PROCESSING_LLM = "processing_llm"
35
+ PROCESSING_TTS = "processing_tts"
36
+ PLAYING_AUDIO = "playing_audio"
37
+
38
+ # ========================= CLASSES =========================
39
+ class AudioBuffer:
40
+ """Thread-safe circular buffer for audio chunks"""
41
+ def __init__(self, max_size: int = AUDIO_BUFFER_MAX_SIZE):
42
+ self.buffer = deque(maxlen=max_size)
43
+ self.lock = asyncio.Lock()
44
+
45
+ async def add_chunk(self, chunk_data: str):
46
+ """Add base64 encoded audio chunk"""
47
+ async with self.lock:
48
+ decoded = base64.b64decode(chunk_data)
49
+ self.buffer.append(decoded)
50
+
51
+ async def get_all_audio(self) -> bytes:
52
+ """Get all audio data concatenated"""
53
+ async with self.lock:
54
+ return b''.join(self.buffer)
55
+
56
+ async def clear(self):
57
+ """Clear buffer"""
58
+ async with self.lock:
59
+ self.buffer.clear()
60
+
61
+ def size(self) -> int:
62
+ """Get current buffer size"""
63
+ return len(self.buffer)
64
+
65
+
66
+ class SilenceDetector:
67
+ """Detect silence in audio stream"""
68
+ def __init__(self, threshold_ms: int = SILENCE_THRESHOLD_MS, energy_threshold: float = ENERGY_THRESHOLD):
69
+ self.threshold_ms = threshold_ms
70
+ self.energy_threshold = energy_threshold
71
+ self.silence_start = None
72
+ self.sample_rate = 16000
73
+
74
+ def update(self, audio_chunk: bytes) -> int:
75
+ """Update with new audio chunk and return silence duration in ms"""
76
+ if self.is_silence(audio_chunk):
77
+ if self.silence_start is None:
78
+ self.silence_start = datetime.now()
79
+ silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
80
+ return int(silence_duration)
81
+ else:
82
+ self.silence_start = None
83
+ return 0
84
+
85
+ def is_silence(self, audio_chunk: bytes) -> bool:
86
+ """Check if audio chunk is silence"""
87
+ try:
88
+ # Convert bytes to numpy array (assuming 16-bit PCM)
89
+ audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
90
+
91
+ # Calculate RMS energy
92
+ if len(audio_data) == 0:
93
+ return True
94
+
95
+ rms = np.sqrt(np.mean(audio_data.astype(float) ** 2))
96
+ normalized_rms = rms / 32768.0 # Normalize for 16-bit audio
97
+
98
+ return normalized_rms < self.energy_threshold
99
+ except Exception as e:
100
+ log_warning(f"Silence detection error: {e}")
101
+ return False
102
+
103
+ def reset(self):
104
+ """Reset silence detection"""
105
+ self.silence_start = None
106
+
107
+
108
+ class BargeInHandler:
109
+ """Handle user interruptions during TTS playback"""
110
+ def __init__(self):
111
+ self.active_tts_task: Optional[asyncio.Task] = None
112
+ self.is_interrupting = False
113
+ self.lock = asyncio.Lock()
114
+
115
+ async def start_tts_task(self, coro):
116
+ """Start a cancellable TTS task"""
117
+ async with self.lock:
118
+ # Cancel any existing task
119
+ if self.active_tts_task and not self.active_tts_task.done():
120
+ self.active_tts_task.cancel()
121
+ try:
122
+ await self.active_tts_task
123
+ except asyncio.CancelledError:
124
+ pass
125
+
126
+ # Start new task
127
+ self.active_tts_task = asyncio.create_task(coro)
128
+ return self.active_tts_task
129
+
130
+ async def handle_interruption(self, current_state: ConversationState):
131
+ """Handle barge-in interruption"""
132
+ async with self.lock:
133
+ self.is_interrupting = True
134
+
135
+ # Cancel TTS if active
136
+ if self.active_tts_task and not self.active_tts_task.done():
137
+ log_info("Barge-in: Cancelling active TTS")
138
+ self.active_tts_task.cancel()
139
+ try:
140
+ await self.active_tts_task
141
+ except asyncio.CancelledError:
142
+ pass
143
+
144
+ # Reset flag after short delay
145
+ await asyncio.sleep(0.5)
146
+ self.is_interrupting = False
147
+
148
+
149
+ class RealtimeSession:
150
+ """Manage a real-time conversation session"""
151
+ def __init__(self, session: Session):
152
+ self.session = session
153
+ self.state = ConversationState.IDLE
154
+
155
+ # Get settings from config
156
+ config = ConfigProvider.get().global_config.stt_provider.settings
157
+
158
+ # Initialize with config values or defaults
159
+ silence_threshold = config.get("speech_timeout_ms", DEFAULT_SILENCE_THRESHOLD_MS)
160
+ energy_threshold = config.get("energy_threshold", DEFAULT_ENERGY_THRESHOLD)
161
+ buffer_max_size = config.get("audio_buffer_max_size", DEFAULT_AUDIO_BUFFER_MAX_SIZE)
162
+
163
+ self.audio_buffer = AudioBuffer(max_size=buffer_max_size)
164
+ self.silence_detector = SilenceDetector(
165
+ threshold_ms=silence_threshold,
166
+ energy_threshold=energy_threshold
167
+ )
168
+ self.barge_in_handler = BargeInHandler()
169
+ self.stt_manager = None
170
+ self.current_transcription = ""
171
+ self.is_streaming = False
172
+ self.lock = asyncio.Lock()
173
+
174
+ # Store config for later use
175
+ self.audio_chunk_size = config.get("audio_chunk_size", DEFAULT_AUDIO_CHUNK_SIZE)
176
+ self.silence_threshold_ms = silence_threshold
177
+
178
+ async def initialize_stt(self):
179
+ """Initialize STT provider"""
180
+ try:
181
+ self.stt_manager = STTFactory.create_provider()
182
+ if self.stt_manager:
183
+ config = ConfigProvider.get().global_config.stt_provider.settings
184
+ await self.stt_manager.start_streaming({
185
+ "language": config.get("language", "tr-TR"),
186
+ "interim_results": config.get("interim_results", True),
187
+ "single_utterance": False,
188
+ "enable_punctuation": config.get("enable_punctuation", True),
189
+ "sample_rate": 16000,
190
+ "encoding": "WEBM_OPUS"
191
+ })
192
+ log_info("STT manager initialized", session_id=self.session.session_id)
193
+ return True
194
+ except Exception as e:
195
+ log_error(f"Failed to initialize STT", error=str(e), session_id=self.session.session_id)
196
+ return False
197
+
198
+ async def change_state(self, new_state: ConversationState):
199
+ """Change conversation state"""
200
+ async with self.lock:
201
+ old_state = self.state
202
+ self.state = new_state
203
+ log_debug(
204
+ f"State change: {old_state.value} → {new_state.value}",
205
+ session_id=self.session.session_id
206
+ )
207
+
208
+ async def handle_barge_in(self):
209
+ """Handle user interruption"""
210
+ await self.barge_in_handler.handle_interruption(self.state)
211
+ await self.change_state(ConversationState.LISTENING)
212
+
213
+ async def reset_for_new_utterance(self):
214
+ """Reset for new user utterance"""
215
+ await self.audio_buffer.clear()
216
+ self.silence_detector.reset()
217
+ self.current_transcription = ""
218
+
219
+ async def cleanup(self):
220
+ """Clean up resources"""
221
+ try:
222
+ if self.stt_manager:
223
+ await self.stt_manager.stop_streaming()
224
+ log_info(f"Cleaned up realtime session", session_id=self.session.session_id)
225
+ except Exception as e:
226
+ log_warning(f"Cleanup error", error=str(e), session_id=self.session.session_id)
227
+
228
+
229
+ # ========================= MAIN HANDLER =========================
230
+ async def websocket_endpoint(websocket: WebSocket, session_id: str):
231
+ """Main WebSocket endpoint for real-time conversation"""
232
+ await websocket.accept()
233
+ log_info(f"WebSocket connected", session_id=session_id)
234
+
235
+ # Get session
236
+ session = session_store.get_session(session_id)
237
+ if not session:
238
+ await websocket.send_json({
239
+ "type": "error",
240
+ "message": "Session not found"
241
+ })
242
+ await websocket.close()
243
+ return
244
+
245
+ # Mark as realtime session
246
+ session.is_realtime_session = True
247
+ session_store.update_session(session)
248
+
249
+ # Initialize conversation
250
+ realtime_session = RealtimeSession(session)
251
+
252
+ # Initialize STT
253
+ stt_initialized = await realtime_session.initialize_stt()
254
+ if not stt_initialized:
255
+ await websocket.send_json({
256
+ "type": "error",
257
+ "message": "STT initialization failed"
258
+ })
259
+
260
+ try:
261
+ while True:
262
+ # Receive message
263
+ message = await websocket.receive_json()
264
+ message_type = message.get("type")
265
+
266
+ if message_type == "audio_chunk":
267
+ await handle_audio_chunk(websocket, realtime_session, message)
268
+
269
+ elif message_type == "control":
270
+ await handle_control_message(websocket, realtime_session, message)
271
+
272
+ elif message_type == "ping":
273
+ # Keep-alive ping
274
+ await websocket.send_json({"type": "pong"})
275
+
276
+ except WebSocketDisconnect:
277
+ log_info(f"WebSocket disconnected", session_id=session_id)
278
+ except Exception as e:
279
+ log_error(
280
+ f"WebSocket error",
281
+ error=str(e),
282
+ traceback=traceback.format_exc(),
283
+ session_id=session_id
284
+ )
285
+ await websocket.send_json({
286
+ "type": "error",
287
+ "message": str(e)
288
+ })
289
+ finally:
290
+ await realtime_session.cleanup()
291
+
292
+
293
+ # ========================= MESSAGE HANDLERS =========================
294
+ async def handle_audio_chunk(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
295
+ """Handle incoming audio chunk with barge-in support"""
296
+ try:
297
+ audio_data = message.get("data")
298
+ if not audio_data:
299
+ return
300
+
301
+ # Check for barge-in during TTS/audio playback
302
+ if session.state in [ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS]:
303
+ await session.handle_barge_in()
304
+ await websocket.send_json({
305
+ "type": "control",
306
+ "action": "stop_playback"
307
+ })
308
+ log_info(f"Barge-in detected", session_id=session.session.session_id, state=session.state.value)
309
+
310
+ # Change state to listening if idle
311
+ if session.state == ConversationState.IDLE:
312
+ await session.change_state(ConversationState.LISTENING)
313
+ await websocket.send_json({
314
+ "type": "state_change",
315
+ "from": "idle",
316
+ "to": "listening"
317
+ })
318
+
319
+ # Add to buffer - don't lose any audio
320
+ await session.audio_buffer.add_chunk(audio_data)
321
+
322
+ # Decode for processing
323
+ decoded_audio = base64.b64decode(audio_data)
324
+
325
+ # Check silence
326
+ silence_duration = session.silence_detector.update(decoded_audio)
327
+
328
+ # Stream to STT if available
329
+ if session.stt_manager and session.state == ConversationState.LISTENING:
330
+ async for result in session.stt_manager.stream_audio(decoded_audio):
331
+ # Send transcription updates
332
+ await websocket.send_json({
333
+ "type": "transcription",
334
+ "text": result.text,
335
+ "is_final": result.is_final,
336
+ "confidence": result.confidence
337
+ })
338
+
339
+ if result.is_final:
340
+ session.current_transcription = result.text
341
+
342
+ # Process if silence detected and we have transcription
343
+ if silence_duration > session.silence_threshold_ms and session.current_transcription:
344
+ log_info(
345
+ f"User stopped speaking",
346
+ session_id=session.session.session_id,
347
+ silence_ms=silence_duration,
348
+ text=session.current_transcription
349
+ )
350
+ await process_user_input(websocket, session)
351
+
352
+ except Exception as e:
353
+ log_error(
354
+ f"Audio chunk handling error",
355
+ error=str(e),
356
+ traceback=traceback.format_exc(),
357
+ session_id=session.session.session_id
358
+ )
359
+ await websocket.send_json({
360
+ "type": "error",
361
+ "message": f"Audio processing error: {str(e)}"
362
+ })
363
+
364
+
365
+ async def handle_control_message(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
366
+ """Handle control messages"""
367
+ action = message.get("action")
368
+ config = message.get("config", {})
369
+
370
+ log_debug(f"Control message", action=action, session_id=session.session.session_id)
371
+
372
+ if action == "start_session":
373
+ # Session configuration
374
+ await websocket.send_json({
375
+ "type": "session_started",
376
+ "session_id": session.session.session_id,
377
+ "config": {
378
+ "silence_threshold_ms": session.silence_threshold_ms,
379
+ "audio_chunk_size": session.audio_chunk_size,
380
+ "supports_barge_in": True
381
+ }
382
+ })
383
+
384
+ elif action == "end_session":
385
+ # Clean up and close
386
+ await session.cleanup()
387
+ await websocket.close()
388
+
389
+ elif action == "interrupt":
390
+ # Handle explicit interrupt
391
+ await session.handle_barge_in()
392
+ await websocket.send_json({
393
+ "type": "control",
394
+ "action": "interrupt_acknowledged"
395
+ })
396
+
397
+ elif action == "reset":
398
+ # Reset conversation state
399
+ await session.reset_for_new_utterance()
400
+ await session.change_state(ConversationState.IDLE)
401
+ await websocket.send_json({
402
+ "type": "state_change",
403
+ "from": session.state.value,
404
+ "to": "idle"
405
+ })
406
+
407
+ elif action == "audio_ended":
408
+ # Audio playback ended on client
409
+ if session.state == ConversationState.PLAYING_AUDIO:
410
+ await session.change_state(ConversationState.IDLE)
411
+ await websocket.send_json({
412
+ "type": "state_change",
413
+ "from": "playing_audio",
414
+ "to": "idle"
415
+ })
416
+
417
+
418
+ # ========================= PROCESSING FUNCTIONS =========================
419
+ async def process_user_input(websocket: WebSocket, session: RealtimeSession):
420
+ """Process complete user input"""
421
+ try:
422
+ user_text = session.current_transcription
423
+ if not user_text:
424
+ await session.reset_for_new_utterance()
425
+ await session.change_state(ConversationState.IDLE)
426
+ return
427
+
428
+ log_info(f"Processing user input", text=user_text, session_id=session.session.session_id)
429
+
430
+ # State: STT Processing
431
+ await session.change_state(ConversationState.PROCESSING_STT)
432
+ await websocket.send_json({
433
+ "type": "state_change",
434
+ "from": "listening",
435
+ "to": "processing_stt"
436
+ })
437
+
438
+ # Send final transcription
439
+ await websocket.send_json({
440
+ "type": "transcription",
441
+ "text": user_text,
442
+ "is_final": True,
443
+ "confidence": 0.95
444
+ })
445
+
446
+ # State: LLM Processing
447
+ await session.change_state(ConversationState.PROCESSING_LLM)
448
+ await websocket.send_json({
449
+ "type": "state_change",
450
+ "from": "processing_stt",
451
+ "to": "processing_llm"
452
+ })
453
+
454
+ # Add to chat history
455
+ session.session.add_message("user", user_text)
456
+
457
+ # Get LLM response based on session state
458
+ if session.session.state == "collect_params":
459
+ response_text = await handle_parameter_followup(session.session, user_text)
460
+ else:
461
+ response_text = await handle_new_message(session.session, user_text)
462
+
463
+ # Add response to history
464
+ session.session.add_message("assistant", response_text)
465
+
466
+ # Send text response
467
+ await websocket.send_json({
468
+ "type": "assistant_response",
469
+ "text": response_text
470
+ })
471
+
472
+ # Generate TTS if enabled
473
+ tts_provider = TTSFactory.create_provider()
474
+ if tts_provider:
475
+ await session.change_state(ConversationState.PROCESSING_TTS)
476
+ await websocket.send_json({
477
+ "type": "state_change",
478
+ "from": "processing_llm",
479
+ "to": "processing_tts"
480
+ })
481
+
482
+ # Generate TTS with barge-in support
483
+ tts_task = session.barge_in_handler.start_tts_task(
484
+ generate_and_stream_tts(websocket, session, tts_provider, response_text)
485
+ )
486
+
487
+ try:
488
+ await tts_task
489
+ except asyncio.CancelledError:
490
+ log_info("TTS cancelled due to barge-in", session_id=session.session.session_id)
491
+ else:
492
+ # No TTS, go back to idle
493
+ await session.change_state(ConversationState.IDLE)
494
+ await websocket.send_json({
495
+ "type": "state_change",
496
+ "from": "processing_llm",
497
+ "to": "idle"
498
+ })
499
+
500
+ # Reset for next input
501
+ await session.reset_for_new_utterance()
502
+
503
+ except Exception as e:
504
+ log_error(
505
+ f"Error processing user input",
506
+ error=str(e),
507
+ traceback=traceback.format_exc(),
508
+ session_id=session.session.session_id
509
+ )
510
+ await websocket.send_json({
511
+ "type": "error",
512
+ "message": f"Processing error: {str(e)}"
513
+ })
514
+ await session.reset_for_new_utterance()
515
+ await session.change_state(ConversationState.IDLE)
516
+
517
+
518
+ async def generate_and_stream_tts(
519
+ websocket: WebSocket,
520
+ session: RealtimeSession,
521
+ tts_provider,
522
+ text: str
523
+ ):
524
+ """Generate and stream TTS audio with cancellation support"""
525
+ try:
526
+ # Generate audio
527
+ audio_data = await tts_provider.synthesize(text)
528
+
529
+ # Change state to playing
530
+ await session.change_state(ConversationState.PLAYING_AUDIO)
531
+ await websocket.send_json({
532
+ "type": "state_change",
533
+ "from": "processing_tts",
534
+ "to": "playing_audio"
535
+ })
536
+
537
+ # Stream audio in chunks
538
+ chunk_size = session.audio_chunk_size
539
+ total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
540
+
541
+ for i in range(0, len(audio_data), chunk_size):
542
+ # Check for cancellation
543
+ if asyncio.current_task().cancelled():
544
+ break
545
+
546
+ chunk = audio_data[i:i + chunk_size]
547
+ chunk_index = i // chunk_size
548
+
549
+ await websocket.send_json({
550
+ "type": "tts_audio",
551
+ "data": base64.b64encode(chunk).decode('utf-8'),
552
+ "chunk_index": chunk_index,
553
+ "total_chunks": total_chunks,
554
+ "is_last": chunk_index == total_chunks - 1
555
+ })
556
+
557
+ # Small delay to prevent overwhelming the client
558
+ await asyncio.sleep(0.01)
559
+
560
+ log_info(
561
+ f"TTS streaming completed",
562
+ session_id=session.session.session_id,
563
+ text_length=len(text),
564
+ audio_size=len(audio_data)
565
+ )
566
+
567
+ except asyncio.CancelledError:
568
+ log_info("TTS streaming cancelled", session_id=session.session.session_id)
569
+ raise
570
+ except Exception as e:
571
+ log_error(
572
+ f"TTS generation error",
573
+ error=str(e),
574
+ session_id=session.session.session_id
575
+ )
576
+ await websocket.send_json({
577
+ "type": "error",
578
+ "message": f"TTS error: {str(e)}"
579
+ })