Spaces:
Building
Building
""" | |
WebSocket Manager for Flare | |
=========================== | |
Manages WebSocket connections and message routing | |
""" | |
import base64 | |
import struct | |
import asyncio | |
from typing import Dict, Optional, Set | |
from fastapi import WebSocket, WebSocketDisconnect | |
import json | |
from datetime import datetime | |
import traceback | |
from .event_bus import EventBus, Event, EventType | |
from utils.logger import log_info, log_error, log_debug, log_warning | |
class WebSocketConnection:
    """Wrapper around a FastAPI ``WebSocket`` that tracks per-connection metadata.

    Attributes:
        websocket: The underlying FastAPI WebSocket.
        session_id: Session this connection belongs to.
        connected_at: Naive UTC timestamp (``datetime.utcnow()``) of creation.
        last_activity: Naive UTC timestamp of the last successful send/receive.
        is_active: Becomes False once a send/receive fails or ``close()`` is called.
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        self.websocket = websocket
        self.session_id = session_id
        self.connected_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        self.is_active = True

    async def send_json(self, data: dict):
        """Send ``data`` to the client as JSON.

        Does nothing when the connection is no longer active. On failure the
        error is logged, the connection is marked inactive, and the exception
        is re-raised to the caller.
        """
        try:
            if self.is_active:
                await self.websocket.send_json(data)
                self.last_activity = datetime.utcnow()
        except Exception as e:
            log_error(
                f"β Failed to send message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def receive_json(self) -> dict:
        """Receive one JSON message from the client.

        NOTE(review): the payload is client-controlled, so despite the ``dict``
        annotation this may be any JSON value; callers should type-check it.

        Raises:
            WebSocketDisconnect: The client disconnected (connection marked inactive).
            Exception: Any other receive/decode error is logged, the connection
                is marked inactive, and the error is re-raised.
        """
        try:
            data = await self.websocket.receive_json()
            self.last_activity = datetime.utcnow()
            return data
        except WebSocketDisconnect:
            self.is_active = False
            raise
        except Exception as e:
            log_error(
                f"β Failed to receive message",
                session_id=self.session_id,
                error=str(e)
            )
            self.is_active = False
            raise

    async def close(self):
        """Mark the connection inactive and close the socket (best effort).

        Errors from closing (for example, a socket that is already closed) are
        logged at debug level and otherwise ignored. Task cancellation is NOT
        swallowed, unlike a bare ``except:``.
        """
        self.is_active = False
        try:
            await self.websocket.close()
        except Exception as e:
            log_debug(
                "WebSocket close failed (ignored)",
                session_id=self.session_id,
                error=str(e)
            )
class WebSocketManager:
    """Manages client WebSocket connections and routes messages.

    Inbound client messages are translated into ``EventBus`` events. Selected
    ``EventBus`` events are translated into client messages, queued per
    session, and delivered by a dedicated send loop.
    """

    # Number of leading audio chunks per session that get a detailed debug analysis.
    AUDIO_DEBUG_CHUNK_LIMIT = 5

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> active connection wrapper
        self.connections: Dict[str, WebSocketConnection] = {}
        # session_id -> outbound messages waiting for the send loop
        self.message_queues: Dict[str, asyncio.Queue] = {}
        # session_id -> number of audio chunks analyzed so far (debug aid only)
        self.audio_debug_counters: Dict[str, int] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to the events that must be forwarded to clients."""
        # State events
        self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)
        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
        self.event_bus.subscribe(EventType.STT_STOPPED, self._handle_stt_stopped)
        # TTS events
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_started)
        self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)
        # Error events
        self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept ``websocket`` and register it as the connection for ``session_id``.

        If the session already has a connection, the old one is torn down via
        ``disconnect()`` first. Publishes ``WEBSOCKET_CONNECTED`` when done.
        """
        await websocket.accept()

        if session_id in self.connections:
            log_warning(
                f"β οΈ Existing connection for session, closing old one",
                session_id=session_id
            )
            # NOTE(review): disconnect() also publishes SESSION_ENDED, so a
            # reconnect ends the session just before WEBSOCKET_CONNECTED is
            # published for the new socket. Confirm this is intended.
            await self.disconnect(session_id)

        self.connections[session_id] = WebSocketConnection(websocket, session_id)
        self.message_queues[session_id] = asyncio.Queue()

        log_info(
            f"β WebSocket connected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_CONNECTED,
            session_id=session_id,
            data={}
        ))

    async def disconnect(self, session_id: str):
        """Tear down whatever is registered for ``session_id``.

        Unregisters the connection, its outbound queue, and its audio-debug
        counter, then closes the socket. Afterwards it publishes
        ``WEBSOCKET_DISCONNECTED`` and ``SESSION_ENDED``. The two events are
        published even when no connection was registered.
        """
        # Unregister before awaiting close(). A concurrent disconnect for the
        # same session then cannot find (and try to delete) the entry twice.
        connection = self.connections.pop(session_id, None)
        self.message_queues.pop(session_id, None)
        self.audio_debug_counters.pop(session_id, None)

        if connection:
            await connection.close()

        log_info(
            f"π WebSocket disconnected",
            session_id=session_id,
            total_connections=len(self.connections)
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_DISCONNECTED,
            session_id=session_id,
            data={}
        ))

        # End the session as well
        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "websocket_disconnected"}
        ))

    async def handle_connection(self, websocket: WebSocket, session_id: str):
        """Run the full lifecycle of one client WebSocket.

        The socket is accepted and registered first. A receive loop and a send
        loop then run concurrently until either one finishes. Errors are
        logged and published as ``WEBSOCKET_ERROR``. On exit, both loops are
        cancelled. The session is disconnected only if its registered
        connection still wraps this ``websocket``; a reconnect with the same
        ``session_id`` may already have replaced it.
        """
        tasks = []
        try:
            await self.connect(websocket, session_id)

            # Pump traffic in both directions until either side finishes.
            tasks = [
                asyncio.create_task(self._receive_messages(session_id)),
                asyncio.create_task(self._send_messages(session_id)),
            ]
            done, pending = await asyncio.wait(
                tasks,
                return_when=asyncio.FIRST_COMPLETED
            )

            await self._cancel_tasks(pending)

            # Re-raise whatever ended the finished task so the handlers below
            # see it. Otherwise the exception would never be retrieved.
            for task in done:
                task.result()

        except WebSocketDisconnect:
            log_info(f"WebSocket disconnected normally", session_id=session_id)
        except Exception as e:
            log_error(
                f"β WebSocket error",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )
            await self.event_bus.publish(Event(
                type=EventType.WEBSOCKET_ERROR,
                session_id=session_id,
                data={
                    "error_type": "websocket_error",
                    "message": str(e)
                }
            ))
        finally:
            # Also reached when this handler itself is cancelled: make sure
            # neither pump task outlives it.
            await self._cancel_tasks(tasks)
            # Only tear down the registered connection if it is still ours.
            current = self.connections.get(session_id)
            if current is not None and current.websocket is websocket:
                await self.disconnect(session_id)

    @staticmethod
    async def _cancel_tasks(tasks):
        """Cancel every unfinished task in ``tasks`` and wait for it to stop.

        Outcomes of the cancelled tasks, including ``CancelledError``, are
        collected and discarded.
        """
        unfinished = [task for task in tasks if not task.done()]
        for task in unfinished:
            task.cancel()
        if unfinished:
            await asyncio.gather(*unfinished, return_exceptions=True)

    async def _receive_messages(self, session_id: str):
        """Read client messages until the connection closes, routing each one.

        A client disconnect ends the loop normally. Any other error is logged
        and re-raised.
        """
        connection = self.connections.get(session_id)
        if not connection:
            return

        try:
            while connection.is_active:
                message = await connection.receive_json()

                # The payload is client-controlled JSON. Anything other than an
                # object cannot be routed, and ``.get`` would crash the loop.
                if not isinstance(message, dict):
                    log_warning(
                        "Ignoring non-object message from client",
                        session_id=session_id,
                        payload_type=type(message).__name__
                    )
                    continue

                log_debug(
                    f"π¨ Received message",
                    session_id=session_id,
                    message_type=message.get("type")
                )

                await self._route_client_message(session_id, message)

        except WebSocketDisconnect:
            log_info(f"Client disconnected", session_id=session_id)
        except Exception as e:
            log_error(
                f"β Error receiving messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _send_messages(self, session_id: str):
        """Deliver queued messages to the client while the connection is active.

        If nothing is queued for 30 seconds, a ``{"type": "ping"}`` message is
        sent instead. Send errors are logged and re-raised.
        """
        connection = self.connections.get(session_id)
        queue = self.message_queues.get(session_id)
        if not connection or not queue:
            return

        try:
            while connection.is_active:
                try:
                    message = await asyncio.wait_for(queue.get(), timeout=30.0)
                    await connection.send_json(message)
                    log_debug(
                        f"π€ Sent message",
                        session_id=session_id,
                        message_type=message.get("type")
                    )
                except asyncio.TimeoutError:
                    # Idle: send a ping so the connection is not left silent.
                    await connection.send_json({"type": "ping"})

        except Exception as e:
            log_error(
                f"β Error sending messages",
                session_id=session_id,
                error=str(e)
            )
            raise

    async def _route_client_message(self, session_id: str, message: dict):
        """Dispatch one client message according to its ``type`` field.

        - ``audio_chunk``: published as AUDIO_CHUNK_RECEIVED.
        - ``control``: handled by ``_handle_control_message``.
        - ``ping``: answered with ``pong``.

        Unknown types are logged and ignored.
        """
        message_type = message.get("type")

        if message_type == "audio_chunk":
            audio_data_base64 = message.get("data")
            if audio_data_base64:
                self._log_audio_debug(session_id, audio_data_base64)

            await self.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=session_id,
                data={
                    "audio_data": audio_data_base64,
                    "timestamp": message.get("timestamp")
                }
            ))

        elif message_type == "control":
            await self._handle_control_message(session_id, message)

        elif message_type == "ping":
            await self.send_message(session_id, {"type": "pong"})

        else:
            log_warning(
                f"β οΈ Unknown message type",
                session_id=session_id,
                message_type=message_type
            )

    async def _handle_control_message(self, session_id: str, message: dict):
        """Translate a ``control`` message's ``action`` into an event-bus event."""
        action = message.get("action")
        config = message.get("config", {})

        if action == "start_conversation":
            # Start a conversation for the existing session.
            log_info(f"π€ Starting conversation for session | session_id={session_id}")
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_STARTED,
                session_id=session_id,
                data={
                    "config": config,
                    "continuous_listening": config.get("continuous_listening", True)
                }
            ))
            # Confirm to the client.
            await self.send_message(session_id, {
                "type": "conversation_started",
                "message": "Conversation started successfully"
            })

        elif action == "stop_conversation":
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "start_session":
            # Deprecated: should no longer be sent, but is still handled as a
            # conversation start (with the raw config as event data).
            log_warning(f"β οΈ Deprecated start_session action received | session_id={session_id}")
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_STARTED,
                session_id=session_id,
                data=config
            ))

        elif action == "stop_session":
            await self.event_bus.publish(Event(
                type=EventType.CONVERSATION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "end_session":
            await self.event_bus.publish(Event(
                type=EventType.SESSION_ENDED,
                session_id=session_id,
                data={"reason": "user_request"}
            ))

        elif action == "audio_ended":
            await self.event_bus.publish(Event(
                type=EventType.AUDIO_PLAYBACK_COMPLETED,
                session_id=session_id,
                data={}
            ))

        else:
            log_warning(
                f"β οΈ Unknown control action",
                session_id=session_id,
                action=action
            )

    def _log_audio_debug(self, session_id: str, audio_data_base64: str) -> None:
        """Log level statistics for the first few audio chunks of a session.

        This is a debug aid only. Any failure is logged and never affects
        routing. The decoded bytes are interpreted as little-endian signed
        16-bit samples (Linear16). NOTE(review): that is an assumption about
        the client's audio format; confirm it.
        """
        count = self.audio_debug_counters.get(session_id, 0)
        if count >= self.AUDIO_DEBUG_CHUNK_LIMIT:
            # Past the debug window: skip the base64 decode entirely.
            return

        try:
            audio_data = base64.b64decode(audio_data_base64)

            log_info(f"π Audio chunk analysis #{count}",
                     session_id=session_id,
                     size_bytes=len(audio_data),
                     base64_size=len(audio_data_base64))

            # Show the first 20 bytes as hex and as ten int16 samples.
            if len(audio_data) >= 20:
                log_debug(f" First 20 bytes (hex): {audio_data[:20].hex()}")
                samples = struct.unpack('<10h', audio_data[:20])
                log_debug(f" First 10 samples: {samples}")
                log_debug(f" Max amplitude (first 10): {max(abs(s) for s in samples)}")

            # Level statistics over the whole chunk (a trailing odd byte is ignored).
            total_samples = len(audio_data) // 2
            if total_samples > 0:
                all_samples = struct.unpack(f'<{total_samples}h', audio_data[:total_samples * 2])
                max_amp = max(abs(s) for s in all_samples)
                avg_amp = sum(abs(s) for s in all_samples) / total_samples

                # Silence check: a low threshold for Linear16.
                silent = max_amp < 100
                log_info(f" Audio stats - Max: {max_amp}, Avg: {avg_amp:.1f}, Silent: {silent}")

                # Warn if the level is very low.
                if max_amp < 50:
                    log_warning(f"β οΈ Very low audio level detected! Max amplitude: {max_amp}")

            self.audio_debug_counters[session_id] = count + 1

        except Exception as e:
            log_error(f"Error analyzing audio chunk: {e}")

    async def send_message(self, session_id: str, message: dict):
        """Queue ``message`` for delivery by the session's send loop.

        Logs a warning and drops the message if the session has no queue.
        """
        queue = self.message_queues.get(session_id)
        if queue:
            await queue.put(message)
        else:
            log_warning(
                f"β οΈ No queue for session",
                session_id=session_id
            )

    async def broadcast_to_session(self, session_id: str, message: dict):
        """Send ``message`` immediately, bypassing the session's queue."""
        connection = self.connections.get(session_id)
        if connection and connection.is_active:
            await connection.send_json(message)

    # Event handlers that forward events to clients

    async def _handle_state_transition(self, event: Event):
        """Forward a state transition to the client."""
        await self.send_message(event.session_id, {
            "type": "state_change",
            "from": event.data.get("old_state"),
            "to": event.data.get("new_state")
        })

    async def _handle_stt_ready(self, event: Event):
        """Tell the client that STT is ready."""
        await self.send_message(event.session_id, {
            "type": "stt_ready",
            "message": "STT is ready to receive audio"
        })

    async def _handle_stt_stopped(self, event: Event):
        """Tell the client that STT stopped, optionally asking it to stop recording."""
        session_id = event.session_id
        stop_recording = event.data.get("stop_recording", False)

        message = {
            "type": "stt_stopped",
            "message": "STT stopped",
            "stop_recording": stop_recording
        }

        if stop_recording:
            message["message"] = "STT stopped, please stop sending audio immediately"
            log_info(f"π€ Sending stop recording command to frontend", session_id=session_id)

        await self.send_message(session_id, message)

    async def _handle_stt_result(self, event: Event):
        """Forward an STT result to the client."""
        # In batch mode only final results arrive.
        await self.send_message(event.session_id, {
            "type": "transcription",
            "text": event.data.get("text", ""),
            "is_final": True,  # Always true in batch mode
            "confidence": event.data.get("confidence", 0.0)
        })

    async def _handle_tts_started(self, event: Event):
        """Send the welcome text to the client when welcome TTS starts."""
        if event.data.get("is_welcome"):
            await self.send_message(event.session_id, {
                "type": "assistant_response",
                "text": event.data.get("text", ""),
                "is_welcome": True
            })

    async def _handle_tts_chunk(self, event: Event):
        """Forward one TTS audio chunk to the client."""
        await self.send_message(event.session_id, {
            "type": "tts_audio",
            "data": event.data.get("audio_data"),
            "chunk_index": event.data.get("chunk_index"),
            "total_chunks": event.data.get("total_chunks"),
            "is_last": event.data.get("is_last", False),
            "mime_type": event.data.get("mime_type", "audio/mpeg")
        })

    async def _handle_tts_completed(self, event: Event):
        """No-op: the client detects completion from the chunks' ``is_last`` flag."""
        pass

    async def _handle_llm_response(self, event: Event):
        """Forward an LLM response to the client."""
        await self.send_message(event.session_id, {
            "type": "assistant_response",
            "text": event.data.get("text", ""),
            "is_welcome": event.data.get("is_welcome", False)
        })

    async def _handle_error(self, event: Event):
        """Forward a recoverable or critical error to the client."""
        error_type = event.data.get("error_type", "unknown")
        message = event.data.get("message", "An error occurred")

        await self.send_message(event.session_id, {
            "type": "error",
            "error_type": error_type,
            "message": message,
            "details": event.data.get("details", {})
        })

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.connections)

    def get_session_connections(self) -> Set[str]:
        """Return the IDs of all sessions with a registered connection."""
        return set(self.connections.keys())

    async def close_all_connections(self):
        """Disconnect every registered session."""
        # Snapshot the keys: disconnect() mutates self.connections.
        session_ids = list(self.connections.keys())
        for session_id in session_ids:
            await self.disconnect(session_id)