ciyidogan commited on
Commit
e231b55
·
verified ·
1 Parent(s): 9b48788

Create websocket_handler.py

Browse files
Files changed (1) hide show
  1. websocket_handler.py +427 -0
websocket_handler.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket Handler for Real-time STT/TTS
3
+ """
4
+ from fastapi import WebSocket, WebSocketDisconnect, HTTPException
5
+ from typing import Dict, Any, Optional
6
+ import json
7
+ import asyncio
8
+ import base64
9
+ from datetime import datetime
10
+ import sys
11
+ import numpy as np
12
+ from enum import Enum
13
+
14
+ from session import Session, session_store
15
+ from config_provider import ConfigProvider
16
+ from chat_handler import _handle_new_message, _handle_parameter_followup
17
+ from stt_factory import STTFactory
18
+ from tts_factory import TTSFactory
19
+ from utils import log
20
+
21
# ========================= CONSTANTS =========================
# How long (ms) the user must stay silent before the utterance is considered finished.
SILENCE_THRESHOLD_MS = 2000
# Size (bytes) of each outgoing TTS audio chunk sent over the WebSocket.
AUDIO_CHUNK_SIZE = 4096
# Normalized RMS energy (0..1) below which a 16-bit PCM chunk counts as silence.
ENERGY_THRESHOLD = 0.01
25
+
26
+ # ========================= ENUMS =========================
27
class ConversationState(Enum):
    """States of the real-time voice conversation loop.

    Driven by websocket_endpoint / handle_audio_chunk / process_user_input:
    IDLE -> LISTENING -> PROCESSING_STT -> PROCESSING_LLM -> PROCESSING_TTS
    -> PLAYING_AUDIO, then back to LISTENING (barge-in) or IDLE.
    """
    IDLE = "idle"                      # waiting for user audio
    LISTENING = "listening"            # accumulating user speech
    PROCESSING_STT = "processing_stt"  # finalizing transcription
    PROCESSING_LLM = "processing_llm"  # generating assistant reply
    PROCESSING_TTS = "processing_tts"  # synthesizing reply audio
    PLAYING_AUDIO = "playing_audio"    # client is playing TTS audio
34
+
35
+ # ========================= CLASSES =========================
36
class AudioBuffer:
    """Accumulates base64-delivered audio chunks into one byte stream."""

    def __init__(self):
        # Decoded chunks in arrival order, plus a running byte count.
        self.chunks = []
        self.total_size = 0

    def add_chunk(self, chunk_data: str):
        """Decode a base64 chunk and append it to the buffer."""
        raw = base64.b64decode(chunk_data)
        self.chunks.append(raw)
        self.total_size += len(raw)

    def get_audio(self) -> bytes:
        """Return all buffered chunks joined into a single bytes object."""
        return b''.join(self.chunks)

    def clear(self):
        """Drop all buffered audio and reset the byte counter."""
        del self.chunks[:]
        self.total_size = 0
56
+
57
class SilenceDetector:
    """Detect sustained silence in a 16-bit PCM audio stream."""

    def __init__(self, threshold_ms: int = SILENCE_THRESHOLD_MS, energy_threshold: float = ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms          # silence length meaning "stopped speaking"
        self.energy_threshold = energy_threshold  # normalized RMS below this is silence
        self.silence_start = None                 # datetime when the current silent run began
        self.sample_rate = 16000  # Default sample rate

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Return True if the chunk's normalized RMS energy is below the threshold.

        Assumes little-endian signed 16-bit PCM samples — TODO confirm with the
        client's capture format.
        """
        try:
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
            if audio_data.size == 0:
                # An empty chunk carries no energy; also avoids mean-of-empty NaN.
                return True

            # BUGFIX: square in float64. `audio_data**2` on int16 wraps modulo
            # 2**16 (e.g. 2560**2 -> 0), producing bogus RMS values or NaN from
            # sqrt of a negative mean.
            samples = audio_data.astype(np.float64)
            rms = np.sqrt(np.mean(samples ** 2))
            normalized_rms = rms / 32768.0  # Normalize for 16-bit audio

            return normalized_rms < self.energy_threshold
        except Exception as e:
            log(f"⚠️ Silence detection error: {e}")
            # Fail open: treat undecodable chunks as speech so we never cut the user off.
            return False

    def update(self, audio_chunk: bytes) -> Optional[int]:
        """Feed one chunk; return the running silence duration in ms (0 while speaking)."""
        is_silent = self.is_silence(audio_chunk)

        if is_silent:
            if self.silence_start is None:
                # First silent chunk of a run — start the clock.
                self.silence_start = datetime.now()
                log("🔇 Silence started")
            else:
                silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
                return int(silence_duration)
        else:
            if self.silence_start is not None:
                log("🔊 Speech detected, silence broken")
                self.silence_start = None

        return 0
97
+
98
class BargeInHandler:
    """Track user interruptions (barge-in) of assistant output."""

    def __init__(self):
        # State the conversation was in when the user barged in, if any.
        self.interrupted_at_state: Optional[ConversationState] = None
        self.accumulated_text: str = ""
        self.pending_audio_chunks = []

    def handle_interruption(self, current_state: ConversationState):
        """Record that the user interrupted while in ``current_state``."""
        self.interrupted_at_state = current_state
        log(f"🛑 Barge-in detected at state: {current_state.value}")

    def should_preserve_context(self) -> bool:
        """Return True when the interruption happened mid-response.

        A barge-in during LLM/TTS processing or playback means the partially
        produced assistant turn is worth keeping.
        """
        preserve_states = (
            ConversationState.PROCESSING_LLM,
            ConversationState.PROCESSING_TTS,
            ConversationState.PLAYING_AUDIO,
        )
        return self.interrupted_at_state in preserve_states
118
+
119
class ConversationManager:
    """Manage conversation state and flow for one WebSocket connection.

    Owns the audio buffer, silence detector, barge-in tracker and the
    STT streaming session for a single client session.
    """
    def __init__(self, session: Session):
        self.session = session
        self.state = ConversationState.IDLE
        self.audio_buffer = AudioBuffer()
        self.silence_detector = SilenceDetector()
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None          # created lazily by initialize_stt()
        self.current_transcription = ""  # latest finalized STT text for this utterance
        self.is_streaming = False

    async def initialize_stt(self):
        """Initialize the STT provider and open a continuous streaming session.

        Returns a truthy value on success and a falsy one on failure
        (False on exception; None when no provider is configured).
        """
        try:
            self.stt_manager = STTFactory.create_provider()
            if self.stt_manager:
                config = ConfigProvider.get().global_config.stt_settings
                await self.stt_manager.start_streaming({
                    "language": config.get("language", "tr-TR"),
                    "interim_results": config.get("interim_results", True),
                    "single_utterance": False,  # Important for continuous listening
                    "enable_punctuation": config.get("enable_punctuation", True)
                })
                log("✅ STT manager initialized")
                return True
            # NOTE(review): with no provider this falls through and returns None
            # (falsy) — the caller treats that as initialization failure. Confirm
            # this is intended rather than a missing `return False`.
        except Exception as e:
            log(f"❌ Failed to initialize STT: {e}")
            return False

    def change_state(self, new_state: ConversationState):
        """Change conversation state, logging the transition."""
        old_state = self.state
        self.state = new_state
        log(f"📊 State change: {old_state.value} → {new_state.value}")

    def handle_barge_in(self):
        """Handle user interruption: record it and drop back to LISTENING."""
        self.barge_in_handler.handle_interruption(self.state)
        self.change_state(ConversationState.LISTENING)

    def reset_audio_buffer(self):
        """Reset buffered audio, silence tracking and transcription for a new utterance."""
        self.audio_buffer.clear()
        self.silence_detector.silence_start = None
        self.current_transcription = ""
165
+
166
+ # ========================= WEBSOCKET HANDLER =========================
167
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation.

    Dispatches "audio_chunk", "control" and "ping" frames until the client
    disconnects. STT-stream cleanup is guaranteed via ``finally`` — the
    original skipped cleanup when the error report itself failed to send
    on an already-dead socket.
    """
    await websocket.accept()
    log(f"🔌 WebSocket connected for session: {session_id}")

    # Get session
    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    # Initialize conversation manager
    conversation = ConversationManager(session)

    # Initialize STT; on failure we report it but keep the socket open so the
    # client can still send control messages (text-only flow).
    stt_initialized = await conversation.initialize_stt()
    if not stt_initialized:
        await websocket.send_json({
            "type": "error",
            "message": "STT initialization failed"
        })

    try:
        while True:
            # Receive message
            message = await websocket.receive_json()
            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, conversation, message)

            elif message_type == "control":
                await handle_control_message(websocket, conversation, message)

            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        log(f"🔌 WebSocket disconnected for session: {session_id}")
    except Exception as e:
        log(f"❌ WebSocket error: {e}")
        # The socket may already be closed; never let the error report
        # raise and mask cleanup.
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception:
            pass
    finally:
        # Runs on disconnect, error, or normal exit — STT stream always closed.
        await cleanup_conversation(conversation)
219
+
220
+ # ========================= MESSAGE HANDLERS =========================
221
async def handle_audio_chunk(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle one incoming base64-encoded audio chunk from the client.

    Responsibilities: barge-in detection while TTS is being produced/played,
    the IDLE -> LISTENING transition, buffering, streaming the chunk to STT,
    and end-of-utterance detection via the silence threshold.
    """
    try:
        audio_data = message.get("data")
        if not audio_data:
            return

        # Check for barge-in: audio arriving while the assistant is speaking
        # means the user interrupted — tell the client to stop playback.
        if conversation.state in [ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS]:
            conversation.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })

        # Change state to listening if idle
        if conversation.state == ConversationState.IDLE:
            conversation.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # Add to buffer (AudioBuffer base64-decodes internally)
        conversation.audio_buffer.add_chunk(audio_data)

        # Decode again for local processing (silence detection / STT streaming)
        decoded_audio = base64.b64decode(audio_data)

        # Check silence — returns the running silence duration in ms (0 while speaking)
        silence_duration = conversation.silence_detector.update(decoded_audio)

        # Stream to STT if available
        if conversation.stt_manager and conversation.state == ConversationState.LISTENING:
            async for result in conversation.stt_manager.stream_audio(decoded_audio):
                # Send interim results so the UI can show a live transcription
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })

                if result.is_final:
                    # Keep only the latest finalized text for this utterance
                    conversation.current_transcription = result.text

        # Check if user stopped speaking (SILENCE_THRESHOLD_MS of silence) and
        # we actually have a transcription to act on.
        if silence_duration > SILENCE_THRESHOLD_MS and conversation.current_transcription:
            log(f"🔇 User stopped speaking after {silence_duration}ms of silence")
            await process_user_input(websocket, conversation)

    except Exception as e:
        log(f"❌ Audio chunk handling error: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })
279
+
280
async def handle_control_message(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle client control messages.

    Supported actions: start_session, end_session, interrupt, reset.
    Unknown actions are ignored.
    """
    action = message.get("action")

    if action == "start_session":
        # Session already started
        await websocket.send_json({
            "type": "session_started",
            "session_id": conversation.session.session_id
        })

    elif action == "end_session":
        # Clean up and close
        await cleanup_conversation(conversation)
        await websocket.close()

    elif action == "interrupt":
        # Handle explicit interrupt
        conversation.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })

    elif action == "reset":
        # BUGFIX: capture the state BEFORE changing it — the original read
        # conversation.state.value after change_state, so "from" was always
        # reported as "idle".
        previous_state = conversation.state.value
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": previous_state,
            "to": "idle"
        })
313
+
314
+ # ========================= PROCESSING FUNCTIONS =========================
315
async def process_user_input(websocket: WebSocket, conversation: ConversationManager):
    """Process a complete user utterance: finalize STT -> LLM -> optional TTS.

    Drives the state machine processing_stt -> processing_llm ->
    (processing_tts -> playing_audio | idle), notifying the client of every
    transition, then resets the audio buffer for the next utterance. On any
    error, reports it and returns the conversation to idle.
    """
    try:
        user_text = conversation.current_transcription
        if not user_text:
            # Nothing transcribed — drop back to idle and wait for speech.
            conversation.reset_audio_buffer()
            conversation.change_state(ConversationState.IDLE)
            return

        log(f"💬 Processing user input: {user_text}")

        # Change state to processing
        conversation.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # Send final transcription
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        # Process with LLM
        conversation.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        # Add to session history
        conversation.session.add_turn("user", user_text)

        # Route by session state: parameter follow-up vs. a fresh message.
        if conversation.session.state == "await_param":
            response_text = await _handle_parameter_followup(conversation.session, user_text)
        else:
            response_text = await _handle_new_message(conversation.session, user_text)

        # Add response to history
        conversation.session.add_turn("assistant", response_text)

        # Send text response
        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        # Generate TTS if enabled
        audio_data = None
        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            conversation.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            # Generate audio
            audio_data = await tts_provider.synthesize(response_text)

        if audio_data:
            # Send audio in chunks. CONSISTENCY FIX: use the module-level
            # AUDIO_CHUNK_SIZE constant instead of a duplicated local 4096.
            for i in range(0, len(audio_data), AUDIO_CHUNK_SIZE):
                chunk = audio_data[i:i + AUDIO_CHUNK_SIZE]
                await websocket.send_json({
                    "type": "tts_audio",
                    "data": base64.b64encode(chunk).decode('utf-8'),
                    "chunk_index": i // AUDIO_CHUNK_SIZE,
                    "is_last": i + AUDIO_CHUNK_SIZE >= len(audio_data)
                })

            conversation.change_state(ConversationState.PLAYING_AUDIO)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_tts",
                "to": "playing_audio"
            })
        else:
            # No TTS provider, or synthesis produced no audio: go straight back
            # to idle. ROBUSTNESS FIX: the original entered playing_audio even
            # when zero chunks were sent, leaving the client waiting for an
            # "is_last" frame that never arrived.
            previous_state = conversation.state.value
            conversation.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": previous_state,
                "to": "idle"
            })

        # Reset for next input
        conversation.reset_audio_buffer()

    except Exception as e:
        log(f"❌ Error processing user input: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Processing error: {str(e)}"
        })
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
418
+
419
+ # ========================= CLEANUP =========================
420
async def cleanup_conversation(conversation: ConversationManager):
    """Release per-conversation resources (closes the STT streaming session)."""
    try:
        stt = conversation.stt_manager
        if stt:
            await stt.stop_streaming()
        log(f"🧹 Cleaned up conversation for session: {conversation.session.session_id}")
    except Exception as e:
        log(f"⚠️ Cleanup error: {e}")