ciyidogan commited on
Commit
282b8e9
·
verified ·
1 Parent(s): 1784117

Delete websocket_manager.py

Browse files
Files changed (1) hide show
  1. websocket_manager.py +0 -523
websocket_manager.py DELETED
@@ -1,523 +0,0 @@
1
- """
2
- WebSocket Manager for Flare
3
- ===========================
4
- Manages WebSocket connections and message routing
5
- """
6
- import base64
7
- import struct
8
- import asyncio
9
- from typing import Dict, Optional, Set
10
- from fastapi import WebSocket, WebSocketDisconnect
11
- import json
12
- from datetime import datetime
13
- import traceback
14
-
15
- from event_bus import EventBus, Event, EventType
16
- from utils.logger import log_info, log_error, log_debug, log_warning
17
-
18
-
19
- class WebSocketConnection:
20
- """Wrapper for WebSocket connection with metadata"""
21
-
22
- def __init__(self, websocket: WebSocket, session_id: str):
23
- self.websocket = websocket
24
- self.session_id = session_id
25
- self.connected_at = datetime.utcnow()
26
- self.last_activity = datetime.utcnow()
27
- self.is_active = True
28
-
29
- async def send_json(self, data: dict):
30
- """Send JSON data to client"""
31
- try:
32
- if self.is_active:
33
- await self.websocket.send_json(data)
34
- self.last_activity = datetime.utcnow()
35
- except Exception as e:
36
- log_error(
37
- f"❌ Failed to send message",
38
- session_id=self.session_id,
39
- error=str(e)
40
- )
41
- self.is_active = False
42
- raise
43
-
44
- async def receive_json(self) -> dict:
45
- """Receive JSON data from client"""
46
- try:
47
- data = await self.websocket.receive_json()
48
- self.last_activity = datetime.utcnow()
49
- return data
50
- except WebSocketDisconnect:
51
- self.is_active = False
52
- raise
53
- except Exception as e:
54
- log_error(
55
- f"❌ Failed to receive message",
56
- session_id=self.session_id,
57
- error=str(e)
58
- )
59
- self.is_active = False
60
- raise
61
-
62
- async def close(self):
63
- """Close the connection"""
64
- try:
65
- self.is_active = False
66
- await self.websocket.close()
67
- except:
68
- pass
69
-
70
-
71
- class WebSocketManager:
72
- """Manages WebSocket connections and routing"""
73
-
74
- def __init__(self, event_bus: EventBus):
75
- self.event_bus = event_bus
76
- self.connections: Dict[str, WebSocketConnection] = {}
77
- self.message_queues: Dict[str, asyncio.Queue] = {}
78
- self._setup_event_handlers()
79
-
80
- def _setup_event_handlers(self):
81
- """Subscribe to events that need to be sent to clients"""
82
- # State events
83
- self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)
84
-
85
- # STT events
86
- self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
87
- self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
88
-
89
- # TTS events
90
- self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_started)
91
- self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
92
- self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
93
-
94
- # LLM events
95
- self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)
96
-
97
- # Error events
98
- self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
99
- self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)
100
-
101
- async def connect(self, websocket: WebSocket, session_id: str):
102
- """Accept new WebSocket connection"""
103
- await websocket.accept()
104
-
105
- # Check for existing connection
106
- if session_id in self.connections:
107
- log_warning(
108
- f"⚠️ Existing connection for session, closing old one",
109
- session_id=session_id
110
- )
111
- await self.disconnect(session_id)
112
-
113
- # Create connection wrapper
114
- connection = WebSocketConnection(websocket, session_id)
115
- self.connections[session_id] = connection
116
-
117
- # Create message queue
118
- self.message_queues[session_id] = asyncio.Queue()
119
-
120
- log_info(
121
- f"✅ WebSocket connected",
122
- session_id=session_id,
123
- total_connections=len(self.connections)
124
- )
125
-
126
- # Publish connection event
127
- await self.event_bus.publish(Event(
128
- type=EventType.WEBSOCKET_CONNECTED,
129
- session_id=session_id,
130
- data={}
131
- ))
132
-
133
- async def disconnect(self, session_id: str):
134
- """Disconnect WebSocket connection"""
135
- connection = self.connections.get(session_id)
136
- if connection:
137
- await connection.close()
138
- del self.connections[session_id]
139
-
140
- # Remove message queue
141
- if session_id in self.message_queues:
142
- del self.message_queues[session_id]
143
-
144
- log_info(
145
- f"🔌 WebSocket disconnected",
146
- session_id=session_id,
147
- total_connections=len(self.connections)
148
- )
149
-
150
- # Publish disconnection event
151
- await self.event_bus.publish(Event(
152
- type=EventType.WEBSOCKET_DISCONNECTED,
153
- session_id=session_id,
154
- data={}
155
- ))
156
-
157
- async def handle_connection(self, websocket: WebSocket, session_id: str):
158
- """Handle WebSocket connection lifecycle"""
159
- try:
160
- # Connect
161
- await self.connect(websocket, session_id)
162
-
163
- # Create tasks for bidirectional communication
164
- receive_task = asyncio.create_task(self._receive_messages(session_id))
165
- send_task = asyncio.create_task(self._send_messages(session_id))
166
-
167
- # Wait for either task to complete
168
- done, pending = await asyncio.wait(
169
- [receive_task, send_task],
170
- return_when=asyncio.FIRST_COMPLETED
171
- )
172
-
173
- # Cancel pending tasks
174
- for task in pending:
175
- task.cancel()
176
- try:
177
- await task
178
- except asyncio.CancelledError:
179
- pass
180
-
181
- except WebSocketDisconnect:
182
- log_info(f"WebSocket disconnected normally", session_id=session_id)
183
- except Exception as e:
184
- log_error(
185
- f"❌ WebSocket error",
186
- session_id=session_id,
187
- error=str(e),
188
- traceback=traceback.format_exc()
189
- )
190
-
191
- # Publish error event
192
- await self.event_bus.publish(Event(
193
- type=EventType.WEBSOCKET_ERROR,
194
- session_id=session_id,
195
- data={
196
- "error_type": "websocket_error",
197
- "message": str(e)
198
- }
199
- ))
200
- finally:
201
- # Ensure disconnection
202
- await self.disconnect(session_id)
203
-
204
- async def _receive_messages(self, session_id: str):
205
- """Receive messages from client"""
206
- connection = self.connections.get(session_id)
207
- if not connection:
208
- return
209
-
210
- try:
211
- while connection.is_active:
212
- # Receive message
213
- message = await connection.receive_json()
214
-
215
- log_debug(
216
- f"📨 Received message",
217
- session_id=session_id,
218
- message_type=message.get("type")
219
- )
220
-
221
- # Route message based on type
222
- await self._route_client_message(session_id, message)
223
-
224
- except WebSocketDisconnect:
225
- log_info(f"Client disconnected", session_id=session_id)
226
- except Exception as e:
227
- log_error(
228
- f"❌ Error receiving messages",
229
- session_id=session_id,
230
- error=str(e)
231
- )
232
- raise
233
-
234
- async def _send_messages(self, session_id: str):
235
- """Send queued messages to client"""
236
- connection = self.connections.get(session_id)
237
- queue = self.message_queues.get(session_id)
238
-
239
- if not connection or not queue:
240
- return
241
-
242
- try:
243
- while connection.is_active:
244
- # Wait for message with timeout
245
- try:
246
- message = await asyncio.wait_for(queue.get(), timeout=30.0)
247
-
248
- # Send to client
249
- await connection.send_json(message)
250
-
251
- log_debug(
252
- f"📤 Sent message",
253
- session_id=session_id,
254
- message_type=message.get("type")
255
- )
256
-
257
- except asyncio.TimeoutError:
258
- # Send ping to keep connection alive
259
- await connection.send_json({"type": "ping"})
260
-
261
- except Exception as e:
262
- log_error(
263
- f"❌ Error sending messages",
264
- session_id=session_id,
265
- error=str(e)
266
- )
267
- raise
268
-
269
- async def _route_client_message(self, session_id: str, message: dict):
270
- """Route message from client to appropriate handler"""
271
- message_type = message.get("type")
272
-
273
- if message_type == "audio_chunk":
274
- # Audio data from client
275
- audio_data_base64 = message.get("data")
276
-
277
- if audio_data_base64:
278
- # Debug için audio analizi
279
- try:
280
- import base64
281
- import struct
282
-
283
- # Base64'ten binary'ye çevir
284
- audio_data = base64.b64decode(audio_data_base64)
285
-
286
- # Session için debug counter
287
- if not hasattr(self, 'audio_debug_counters'):
288
- self.audio_debug_counters = {}
289
-
290
- if session_id not in self.audio_debug_counters:
291
- self.audio_debug_counters[session_id] = 0
292
-
293
- # İlk 5 chunk için detaylı log
294
- if self.audio_debug_counters[session_id] < 5:
295
- log_info(f"🔊 Audio chunk analysis #{self.audio_debug_counters[session_id]}",
296
- session_id=session_id,
297
- size_bytes=len(audio_data),
298
- base64_size=len(audio_data_base64))
299
-
300
- # İlk 20 byte'ı hex olarak göster
301
- if len(audio_data) >= 20:
302
- log_debug(f" First 20 bytes (hex): {audio_data[:20].hex()}")
303
-
304
- # Linear16 (little-endian int16) olarak yorumla
305
- samples = struct.unpack('<10h', audio_data[:20])
306
- log_debug(f" First 10 samples: {samples}")
307
- log_debug(f" Max amplitude (first 10): {max(abs(s) for s in samples)}")
308
-
309
- # Tüm chunk'ı analiz et
310
- total_samples = len(audio_data) // 2
311
- if total_samples > 0:
312
- all_samples = struct.unpack(f'<{total_samples}h', audio_data[:total_samples*2])
313
- max_amp = max(abs(s) for s in all_samples)
314
- avg_amp = sum(abs(s) for s in all_samples) / total_samples
315
-
316
- # Sessizlik kontrolü
317
- silent = max_amp < 100 # Linear16 için düşük eşik
318
-
319
- log_info(f" Audio stats - Max: {max_amp}, Avg: {avg_amp:.1f}, Silent: {silent}")
320
-
321
- # Eğer çok sessizse uyar
322
- if max_amp < 50:
323
- log_warning(f"⚠️ Very low audio level detected! Max amplitude: {max_amp}")
324
-
325
- self.audio_debug_counters[session_id] += 1
326
-
327
- except Exception as e:
328
- log_error(f"Error analyzing audio chunk: {e}")
329
-
330
- # Audio data from client
331
- await self.event_bus.publish(Event(
332
- type=EventType.AUDIO_CHUNK_RECEIVED,
333
- session_id=session_id,
334
- data={
335
- "audio_data": message.get("data"),
336
- "timestamp": message.get("timestamp")
337
- }
338
- ))
339
-
340
- elif message_type == "control":
341
- # Control messages
342
- action = message.get("action")
343
- config = message.get("config", {})
344
-
345
- if action == "start_conversation":
346
- # Yeni action: Mevcut session için conversation başlat
347
- log_info(f"🎤 Starting conversation for session | session_id={session_id}")
348
-
349
- await self.event_bus.publish(Event(
350
- type=EventType.CONVERSATION_STARTED,
351
- session_id=session_id,
352
- data={
353
- "config": config,
354
- "continuous_listening": config.get("continuous_listening", True)
355
- }
356
- ))
357
-
358
- # Send confirmation to client
359
- await self.send_message(session_id, {
360
- "type": "conversation_started",
361
- "message": "Conversation started successfully"
362
- })
363
-
364
- elif action == "stop_conversation":
365
- await self.event_bus.publish(Event(
366
- type=EventType.CONVERSATION_ENDED,
367
- session_id=session_id,
368
- data={"reason": "user_request"}
369
- ))
370
-
371
- elif action == "start_session":
372
- # Bu artık kullanılmamalı
373
- log_warning(f"⚠️ Deprecated start_session action received | session_id={session_id}")
374
-
375
- # Yine de işle ama conversation_started olarak
376
- await self.event_bus.publish(Event(
377
- type=EventType.CONVERSATION_STARTED,
378
- session_id=session_id,
379
- data=config
380
- ))
381
-
382
- elif action == "stop_session":
383
- await self.event_bus.publish(Event(
384
- type=EventType.CONVERSATION_ENDED,
385
- session_id=session_id,
386
- data={"reason": "user_request"}
387
- ))
388
-
389
- elif action == "end_session":
390
- await self.event_bus.publish(Event(
391
- type=EventType.SESSION_ENDED,
392
- session_id=session_id,
393
- data={"reason": "user_request"}
394
- ))
395
-
396
- elif action == "audio_ended":
397
- await self.event_bus.publish(Event(
398
- type=EventType.AUDIO_PLAYBACK_COMPLETED,
399
- session_id=session_id,
400
- data={}
401
- ))
402
-
403
- else:
404
- log_warning(
405
- f"⚠️ Unknown control action",
406
- session_id=session_id,
407
- action=action
408
- )
409
-
410
- elif message_type == "ping":
411
- # Respond to ping
412
- await self.send_message(session_id, {"type": "pong"})
413
-
414
- else:
415
- log_warning(
416
- f"⚠️ Unknown message type",
417
- session_id=session_id,
418
- message_type=message_type
419
- )
420
-
421
- async def send_message(self, session_id: str, message: dict):
422
- """Queue message for sending to client"""
423
- queue = self.message_queues.get(session_id)
424
- if queue:
425
- await queue.put(message)
426
- else:
427
- log_warning(
428
- f"⚠️ No queue for session",
429
- session_id=session_id
430
- )
431
-
432
- async def broadcast_to_session(self, session_id: str, message: dict):
433
- """Send message immediately (bypass queue)"""
434
- connection = self.connections.get(session_id)
435
- if connection and connection.is_active:
436
- await connection.send_json(message)
437
-
438
- # Event handlers for sending messages to clients
439
-
440
- async def _handle_state_transition(self, event: Event):
441
- """Send state transition to client"""
442
- await self.send_message(event.session_id, {
443
- "type": "state_change",
444
- "from": event.data.get("old_state"),
445
- "to": event.data.get("new_state")
446
- })
447
-
448
- async def _handle_stt_ready(self, event: Event):
449
- """Send STT ready signal to client"""
450
- await self.send_message(event.session_id, {
451
- "type": "stt_ready",
452
- "message": "STT is ready to receive audio"
453
- })
454
-
455
- async def _handle_stt_result(self, event: Event):
456
- """Send STT result to client"""
457
- # Her türlü result'ı (interim + final) frontend'e gönder
458
- await self.send_message(event.session_id, {
459
- "type": "transcription",
460
- "text": event.data.get("text", ""),
461
- "is_final": event.data.get("is_final", False),
462
- "confidence": event.data.get("confidence", 0.0)
463
- })
464
-
465
- async def _handle_tts_started(self, event: Event):
466
- """Send assistant message when TTS starts"""
467
- if event.data.get("is_welcome"):
468
- # Send welcome message to client
469
- await self.send_message(event.session_id, {
470
- "type": "assistant_response",
471
- "text": event.data.get("text", ""),
472
- "is_welcome": True
473
- })
474
-
475
- async def _handle_tts_chunk(self, event: Event):
476
- """Send TTS audio chunk to client"""
477
- await self.send_message(event.session_id, {
478
- "type": "tts_audio",
479
- "data": event.data.get("audio_data"),
480
- "chunk_index": event.data.get("chunk_index"),
481
- "total_chunks": event.data.get("total_chunks"),
482
- "is_last": event.data.get("is_last", False),
483
- "mime_type": event.data.get("mime_type", "audio/mpeg")
484
- })
485
-
486
- async def _handle_tts_completed(self, event: Event):
487
- """Notify client that TTS is complete"""
488
- # Client knows from is_last flag in chunks
489
- pass
490
-
491
- async def _handle_llm_response(self, event: Event):
492
- """Send LLM response to client"""
493
- await self.send_message(event.session_id, {
494
- "type": "assistant_response",
495
- "text": event.data.get("text", ""),
496
- "is_welcome": event.data.get("is_welcome", False)
497
- })
498
-
499
- async def _handle_error(self, event: Event):
500
- """Send error to client"""
501
- error_type = event.data.get("error_type", "unknown")
502
- message = event.data.get("message", "An error occurred")
503
-
504
- await self.send_message(event.session_id, {
505
- "type": "error",
506
- "error_type": error_type,
507
- "message": message,
508
- "details": event.data.get("details", {})
509
- })
510
-
511
- def get_connection_count(self) -> int:
512
- """Get number of active connections"""
513
- return len(self.connections)
514
-
515
- def get_session_connections(self) -> Set[str]:
516
- """Get all active session IDs"""
517
- return set(self.connections.keys())
518
-
519
- async def close_all_connections(self):
520
- """Close all active connections"""
521
- session_ids = list(self.connections.keys())
522
- for session_id in session_ids:
523
- await self.disconnect(session_id)