ciyidogan commited on
Commit
a5c508e
·
verified ·
1 Parent(s): b488f18

Delete websocket_handler.py

Browse files
Files changed (1) hide show
  1. websocket_handler.py +0 -427
websocket_handler.py DELETED
@@ -1,427 +0,0 @@
1
- """
2
- WebSocket Handler for Real-time STT/TTS
3
- """
4
- from fastapi import WebSocket, WebSocketDisconnect, HTTPException
5
- from typing import Dict, Any, Optional
6
- import json
7
- import asyncio
8
- import base64
9
- from datetime import datetime
10
- import sys
11
- import numpy as np
12
- from enum import Enum
13
-
14
- from session import Session, session_store
15
- from config_provider import ConfigProvider
16
- from chat_handler import handle_new_message, handle_parameter_followup
17
- from stt_factory import STTFactory
18
- from tts_factory import TTSFactory
19
- from utils import log
20
-
21
# ========================= CONSTANTS =========================
# Continuous quiet span (ms) after which the user is considered done speaking.
SILENCE_THRESHOLD_MS = 2000
# Byte size of each TTS audio chunk streamed back to the client.
AUDIO_CHUNK_SIZE = 4096
# Normalized RMS energy below which a PCM chunk counts as silence.
ENERGY_THRESHOLD = 0.01
25
-
26
- # ========================= ENUMS =========================
27
class ConversationState(Enum):
    """States of the real-time voice conversation pipeline.

    Driven by websocket_endpoint and its helpers: IDLE -> LISTENING ->
    PROCESSING_STT -> PROCESSING_LLM -> PROCESSING_TTS -> PLAYING_AUDIO,
    with barge-in jumping back to LISTENING from the TTS/playback states.
    """
    IDLE = "idle"                      # waiting for user audio
    LISTENING = "listening"            # accumulating user speech
    PROCESSING_STT = "processing_stt"  # finalizing the transcription
    PROCESSING_LLM = "processing_llm"  # generating the assistant reply
    PROCESSING_TTS = "processing_tts"  # synthesizing reply audio
    PLAYING_AUDIO = "playing_audio"    # client is playing TTS audio
34
-
35
- # ========================= CLASSES =========================
36
class AudioBuffer:
    """Accumulate base64-encoded audio chunks for a single utterance."""

    def __init__(self):
        # Decoded (raw bytes) chunks, kept in arrival order.
        self.chunks = []
        # Running total of decoded bytes currently buffered.
        self.total_size = 0

    def add_chunk(self, chunk_data: str):
        """Decode a base64 chunk and append it to the buffer."""
        raw = base64.b64decode(chunk_data)
        self.chunks.append(raw)
        self.total_size += len(raw)

    def get_audio(self) -> bytes:
        """Return all buffered chunks joined into one bytes object."""
        return b''.join(self.chunks)

    def clear(self):
        """Drop all buffered audio and reset the byte counter."""
        del self.chunks[:]
        self.total_size = 0
56
-
57
class SilenceDetector:
    """Detect sustained silence in a 16-bit PCM audio stream.

    Remembers when the current silent span began so the caller can poll
    how long the user has been quiet and end the utterance.
    """
    def __init__(self, threshold_ms: int = SILENCE_THRESHOLD_MS, energy_threshold: float = ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.silence_start = None  # datetime when the current silence began
        self.sample_rate = 16000  # Default sample rate (assumed 16 kHz — confirm with the audio source)

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Return True when the chunk's normalized RMS energy is below threshold.

        The chunk is interpreted as 16-bit PCM. Empty or unreadable chunks
        report False so they never end an utterance on their own.
        """
        try:
            # Convert bytes to numpy array (assuming 16-bit PCM)
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16)

            if audio_data.size == 0:
                # No samples: previously np.mean over an empty array yielded
                # NaN (with a RuntimeWarning) and the comparison was False.
                # Make that outcome explicit and warning-free.
                return False

            # Square in float64: squaring in int16 dtype overflows for
            # samples above ~181 and silently wraps, corrupting the RMS
            # so that loud audio could be classified as silence.
            samples = audio_data.astype(np.float64)

            # Calculate RMS energy
            rms = np.sqrt(np.mean(samples ** 2))
            normalized_rms = rms / 32768.0  # Normalize for 16-bit audio

            return normalized_rms < self.energy_threshold
        except Exception as e:
            log(f"⚠️ Silence detection error: {e}")
            return False

    def update(self, audio_chunk: bytes) -> Optional[int]:
        """Update silence tracking; return the current silence duration in ms.

        Returns 0 while speech is present or when silence has only just
        begun; returns the elapsed duration on subsequent silent chunks.
        """
        is_silent = self.is_silence(audio_chunk)

        if is_silent:
            if self.silence_start is None:
                self.silence_start = datetime.now()
                log("🔇 Silence started")
            else:
                silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
                return int(silence_duration)
        else:
            if self.silence_start is not None:
                log("🔊 Speech detected, silence broken")
                self.silence_start = None

        return 0
97
-
98
class BargeInHandler:
    """Track user interruptions (barge-in) and decide context handling."""

    def __init__(self):
        # Pipeline state captured at the moment the user cut in, if any.
        self.interrupted_at_state: Optional[ConversationState] = None
        # Partial text accumulated before the interruption.
        self.accumulated_text: str = ""
        # Audio chunks that were queued for playback when interrupted.
        self.pending_audio_chunks = []

    def handle_interruption(self, current_state: ConversationState):
        """Record the pipeline state at the moment of interruption."""
        self.interrupted_at_state = current_state
        log(f"🛑 Barge-in detected at state: {current_state.value}")

    def should_preserve_context(self) -> bool:
        """Return True when the user broke in mid-response.

        Context is kept if the interruption landed while the assistant was
        generating (LLM), synthesizing (TTS), or playing audio.
        """
        preserve_states = (
            ConversationState.PROCESSING_LLM,
            ConversationState.PROCESSING_TTS,
            ConversationState.PLAYING_AUDIO,
        )
        return self.interrupted_at_state in preserve_states
118
-
119
class ConversationManager:
    """Per-connection conversation state: buffers, silence detection, STT,
    barge-in handling, and the current pipeline state."""
    def __init__(self, session: Session):
        self.session = session                       # project Session object for this client
        self.state = ConversationState.IDLE          # current pipeline state
        self.audio_buffer = AudioBuffer()            # raw audio for the current utterance
        self.silence_detector = SilenceDetector()
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None                      # set by initialize_stt()
        self.current_transcription = ""              # latest final STT text
        self.is_streaming = False

    async def initialize_stt(self):
        """Create the STT provider and open its streaming session.

        Returns True on success, False when creation or startup raises.
        NOTE(review): when STTFactory.create_provider() returns a falsy
        provider, the success path is skipped — confirm whether callers
        expect True or None in that case.
        """
        try:
            self.stt_manager = STTFactory.create_provider()
            if self.stt_manager:
                config = ConfigProvider.get().global_config.stt_settings
                await self.stt_manager.start_streaming({
                    "language": config.get("language", "tr-TR"),
                    "interim_results": config.get("interim_results", True),
                    "single_utterance": False,  # Important for continuous listening
                    "enable_punctuation": config.get("enable_punctuation", True)
                })
                log("✅ STT manager initialized")
                return True
        except Exception as e:
            log(f"❌ Failed to initialize STT: {e}")
            return False

    def change_state(self, new_state: ConversationState):
        """Transition the pipeline state and log the change."""
        old_state = self.state
        self.state = new_state
        log(f"📊 State change: {old_state.value} → {new_state.value}")

    def handle_barge_in(self):
        """Record the interruption at the current state, then go back to LISTENING."""
        self.barge_in_handler.handle_interruption(self.state)
        self.change_state(ConversationState.LISTENING)

    def reset_audio_buffer(self):
        """Clear buffered audio, silence tracking, and transcription for a new utterance."""
        self.audio_buffer.clear()
        self.silence_detector.silence_start = None
        self.current_transcription = ""
165
-
166
- # ========================= WEBSOCKET HANDLER =========================
167
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation.

    Accepts the socket, validates the session, initializes STT, then loops
    over incoming JSON messages (audio_chunk / control / ping) until the
    client disconnects or an unrecoverable error occurs.
    """
    await websocket.accept()
    log(f"🔌 WebSocket connected for session: {session_id}")

    # Get session; refuse the connection if it does not exist.
    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    # Initialize conversation manager
    conversation = ConversationManager(session)

    # Initialize STT; on failure we report it but keep the socket open
    # (text-only operation is still possible).
    stt_initialized = await conversation.initialize_stt()
    if not stt_initialized:
        await websocket.send_json({
            "type": "error",
            "message": "STT initialization failed"
        })

    try:
        while True:
            # Receive and dispatch one JSON message.
            message = await websocket.receive_json()
            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, conversation, message)

            elif message_type == "control":
                await handle_control_message(websocket, conversation, message)

            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        log(f"🔌 WebSocket disconnected for session: {session_id}")
    except Exception as e:
        log(f"❌ WebSocket error: {e}")
        # The socket may already be dead; a failed error report must not
        # prevent resource cleanup (previously it could skip cleanup and
        # leak the STT stream).
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception:
            pass
    finally:
        # Always release STT resources, whatever ended the loop.
        await cleanup_conversation(conversation)
219
-
220
- # ========================= MESSAGE HANDLERS =========================
221
async def handle_audio_chunk(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle one incoming base64 audio chunk.

    Triggers barge-in if the assistant is currently speaking, buffers the
    audio, streams it to STT for interim transcriptions, and — once silence
    exceeds SILENCE_THRESHOLD_MS with a final transcription in hand — hands
    the utterance to process_user_input().
    """
    try:
        audio_data = message.get("data")
        if not audio_data:
            # Nothing to process; ignore empty chunk messages.
            return

        # Check for barge-in: user audio while TTS is being generated or
        # played means an interruption — tell the client to stop playback.
        if conversation.state in [ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS]:
            conversation.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })

        # Change state to listening if idle (first chunk of an utterance).
        if conversation.state == ConversationState.IDLE:
            conversation.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # Add to buffer (kept base64-decoded inside AudioBuffer).
        conversation.audio_buffer.add_chunk(audio_data)

        # Decode for processing (silence detection and STT need raw PCM).
        decoded_audio = base64.b64decode(audio_data)

        # Check silence: 0 while speech continues, else ms of silence so far.
        silence_duration = conversation.silence_detector.update(decoded_audio)

        # Stream to STT if available; forward every interim/final result.
        if conversation.stt_manager and conversation.state == ConversationState.LISTENING:
            async for result in conversation.stt_manager.stream_audio(decoded_audio):
                # Send interim results
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })

                if result.is_final:
                    # Remember the latest final text; consumed below once
                    # the user stops speaking.
                    conversation.current_transcription = result.text

        # Check if user stopped speaking (SILENCE_THRESHOLD_MS of silence)
        # and we actually have something transcribed to act on.
        if silence_duration > SILENCE_THRESHOLD_MS and conversation.current_transcription:
            log(f"🔇 User stopped speaking after {silence_duration}ms of silence")
            await process_user_input(websocket, conversation)

    except Exception as e:
        log(f"❌ Audio chunk handling error: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })
279
-
280
async def handle_control_message(websocket: WebSocket, conversation: ConversationManager, message: Dict[str, Any]):
    """Handle client control messages.

    Supported actions: start_session (ack), end_session (cleanup + close),
    interrupt (explicit barge-in), reset (clear buffers, return to idle).
    Unknown actions are silently ignored.
    """
    action = message.get("action")

    if action == "start_session":
        # Session already started when the socket was accepted — just ack.
        await websocket.send_json({
            "type": "session_started",
            "session_id": conversation.session.session_id
        })

    elif action == "end_session":
        # Clean up and close
        await cleanup_conversation(conversation)
        await websocket.close()

    elif action == "interrupt":
        # Handle explicit interrupt from the client UI.
        conversation.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })

    elif action == "reset":
        # Capture the pre-reset state first: change_state() mutates
        # conversation.state, and the original code read it afterwards,
        # so the client always saw "from": "idle" -> "to": "idle".
        previous_state = conversation.state.value
        # Reset conversation state
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": previous_state,
            "to": "idle"
        })
313
-
314
- # ========================= PROCESSING FUNCTIONS =========================
315
async def process_user_input(websocket: WebSocket, conversation: ConversationManager):
    """Process a complete user utterance end-to-end.

    Walks the pipeline states (STT -> LLM -> TTS -> playback), emitting a
    state_change message at each transition, the final transcription, the
    assistant's text reply, and — when a TTS provider exists — the reply
    audio in base64 chunks. Resets the audio buffer for the next utterance.
    """
    try:
        user_text = conversation.current_transcription
        if not user_text:
            # Nothing transcribed; drop back to idle without responding.
            conversation.reset_audio_buffer()
            conversation.change_state(ConversationState.IDLE)
            return

        log(f"💬 Processing user input: {user_text}")

        # Change state to processing
        conversation.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # Send final transcription.
        # NOTE(review): confidence is hard-coded to 0.95 here rather than
        # taken from the STT result — confirm whether the client uses it.
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        # Process with LLM
        conversation.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        # Add to session history
        conversation.session.add_turn("user", user_text)

        # Get response based on session state: parameter follow-up vs a
        # brand-new message (see chat_handler).
        if conversation.session.state == "await_param":
            response_text = await handle_parameter_followup(conversation.session, user_text)
        else:
            response_text = await handle_new_message(conversation.session, user_text)

        # Add response to history
        conversation.session.add_turn("assistant", response_text)

        # Send text response
        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        # Generate TTS if enabled (factory returns a falsy provider when
        # TTS is not configured).
        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            conversation.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            # Generate audio
            audio_data = await tts_provider.synthesize(response_text)

            # Send audio in chunks so the client can start playback early.
            chunk_size = 4096
            for i in range(0, len(audio_data), chunk_size):
                chunk = audio_data[i:i + chunk_size]
                await websocket.send_json({
                    "type": "tts_audio",
                    "data": base64.b64encode(chunk).decode('utf-8'),
                    "chunk_index": i // chunk_size,
                    "is_last": i + chunk_size >= len(audio_data)
                })

            conversation.change_state(ConversationState.PLAYING_AUDIO)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_tts",
                "to": "playing_audio"
            })
        else:
            # No TTS, go back to idle
            conversation.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "idle"
            })

        # Reset for next input
        conversation.reset_audio_buffer()

    except Exception as e:
        log(f"❌ Error processing user input: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Processing error: {str(e)}"
        })
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
418
-
419
- # ========================= CLEANUP =========================
420
async def cleanup_conversation(conversation: ConversationManager):
    """Release conversation resources (best-effort; never raises)."""
    try:
        stt = conversation.stt_manager
        if stt:
            await stt.stop_streaming()
        log(f"🧹 Cleaned up conversation for session: {conversation.session.session_id}")
    except Exception as e:
        log(f"⚠️ Cleanup error: {e}")