"""
WebSocket Manager for Flare
===========================
Manages WebSocket connections and message routing
"""
import asyncio
from typing import Dict, Optional, Set
from fastapi import WebSocket, WebSocketDisconnect
import json
from datetime import datetime
import traceback

from event_bus import EventBus, Event, EventType
from utils.logger import log_info, log_error, log_debug, log_warning


class WebSocketConnection:
    """Wrapper for a WebSocket connection with activity metadata.

    Records when the socket was wrapped, when it last sent/received data, and
    whether it is still usable (``is_active``). Any send/receive failure flips
    ``is_active`` to False so the manager's send/receive loops stop.
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        self.websocket = websocket
        self.session_id = session_id
        # Naive UTC timestamps (datetime.utcnow()).
        self.connected_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        self.is_active = True

    async def send_json(self, data: dict):
        """Send ``data`` to the client as JSON.

        The send is silently skipped when the connection is already inactive.

        Raises:
            Exception: whatever the underlying send raised; the connection is
                marked inactive before the error is re-raised.
        """
        try:
            if self.is_active:
                await self.websocket.send_json(data)
                self.last_activity = datetime.utcnow()
        except Exception as e:
            log_error(
                "❌ Failed to send message",
                session_id=self.session_id,
                error=str(e),
            )
            self.is_active = False
            raise

    async def receive_json(self) -> dict:
        """Receive one JSON message from the client.

        Returns:
            The decoded payload (normally a dict, but the client controls it).

        Raises:
            WebSocketDisconnect: the client went away (connection marked inactive).
            Exception: any other receive/decode error (logged, connection
                marked inactive, re-raised).
        """
        try:
            data = await self.websocket.receive_json()
            self.last_activity = datetime.utcnow()
            return data
        except WebSocketDisconnect:
            self.is_active = False
            raise
        except Exception as e:
            log_error(
                "❌ Failed to receive message",
                session_id=self.session_id,
                error=str(e),
            )
            self.is_active = False
            raise

    async def close(self):
        """Mark the connection inactive and close the socket (best effort)."""
        self.is_active = False
        try:
            await self.websocket.close()
        except Exception as e:
            # Closing is best effort (e.g. the peer may already have closed the
            # socket). Only ordinary errors are swallowed; cancellation and
            # interpreter-exit signals still propagate.
            log_debug(
                "Ignoring error while closing WebSocket",
                session_id=self.session_id,
                error=str(e),
            )


class WebSocketManager:
    """Manages WebSocket connections and routing between clients and the event bus.

    One connection and one outgoing message queue are kept per session id.
    Client messages are translated into event-bus events, and selected bus
    events are queued back to the client.
    """

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        self.connections: Dict[str, WebSocketConnection] = {}
        self.message_queues: Dict[str, asyncio.Queue] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to the bus events that must be forwarded to clients."""
        # State events
        self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)

        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)

        # TTS events
        self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)

        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)

        # Error events
        self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)

    async def connect(self, websocket: WebSocket, session_id: str) -> WebSocketConnection:
        """Accept ``websocket`` and register it as the connection for ``session_id``.

        Any connection already registered for the session is disconnected first.
        Publishes ``WEBSOCKET_CONNECTED`` once the new connection is registered.

        Returns:
            The newly registered connection wrapper.
        """
        await websocket.accept()

        # Loop rather than check once: disconnect() awaits, and another connect()
        # for the same session may register a socket in the meantime. That
        # socket must be closed too, not silently overwritten.
        while session_id in self.connections:
            log_warning(
                "⚠️ Existing connection for session, closing old one",
                session_id=session_id,
            )
            await self.disconnect(session_id)

        # Create connection wrapper
        connection = WebSocketConnection(websocket, session_id)
        self.connections[session_id] = connection

        # Create message queue
        self.message_queues[session_id] = asyncio.Queue()

        log_info(
            "✅ WebSocket connected",
            session_id=session_id,
            total_connections=len(self.connections),
        )

        # Publish connection event
        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_CONNECTED,
            session_id=session_id,
            data={},
        ))
        return connection

    async def disconnect(self, session_id: str, *, websocket: Optional[WebSocket] = None):
        """Close and unregister the connection for ``session_id``.

        Args:
            session_id: Session whose connection should be torn down.
            websocket: If given, only act when the registered connection wraps
                this exact socket. This lets a handler clean up after itself
                without closing a newer connection that replaced it.

        Publishes ``WEBSOCKET_DISCONNECTED`` only when a connection was actually
        removed. Calling this for an unknown session is a no-op.
        """
        connection = self.connections.get(session_id)
        if connection is None:
            return
        if websocket is not None and connection.websocket is not websocket:
            # The session now belongs to a newer socket; leave it alone.
            return

        # Unregister before awaiting close(), so a concurrent caller can neither
        # remove this entry twice nor remove a replacement registered meanwhile.
        del self.connections[session_id]
        self.message_queues.pop(session_id, None)

        await connection.close()

        log_info(
            "🔌 WebSocket disconnected",
            session_id=session_id,
            total_connections=len(self.connections),
        )

        # Publish disconnection event
        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_DISCONNECTED,
            session_id=session_id,
            data={},
        ))

    async def handle_connection(self, websocket: WebSocket, session_id: str):
        """Run the full lifecycle of one WebSocket connection.

        The flow is:

        1. Connect.
        2. Run a receive loop and a send loop concurrently until either ends.
        3. Cancel the other loop.
        4. Disconnect.

        An error raised by either loop is logged and published as a
        ``WEBSOCKET_ERROR`` event.
        """
        tasks = []
        try:
            connection = await self.connect(websocket, session_id)

            # connect() awaits while publishing; if another connection for the
            # same session replaced ours meanwhile, our socket was already closed.
            if self.connections.get(session_id) is not connection:
                return
            queue = self.message_queues[session_id]

            # Create tasks for bidirectional communication. The connection and
            # queue are passed explicitly so the loops cannot pick up a
            # replacement connection registered under the same session id.
            tasks = [
                asyncio.create_task(self._receive_messages(session_id, connection)),
                asyncio.create_task(self._send_messages(session_id, connection, queue)),
            ]

            # Wait for either task to complete
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

            # Surface a failure from a finished loop so the handlers below see it
            # (otherwise it would be silently dropped and never retrieved).
            errors = [task.exception() for task in done if not task.cancelled()]
            error = next((exc for exc in errors if exc is not None), None)
            if error is not None:
                raise error

        except WebSocketDisconnect:
            log_info("WebSocket disconnected normally", session_id=session_id)
        except Exception as e:
            log_error(
                "❌ WebSocket error",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc(),
            )

            # Publish error event
            await self.event_bus.publish(Event(
                type=EventType.WEBSOCKET_ERROR,
                session_id=session_id,
                data={
                    "error_type": "websocket_error",
                    "message": str(e),
                },
            ))
        finally:
            # Stop whichever loop is still running (cancel() is a no-op on
            # finished tasks). Drain them so no exception goes unretrieved, even
            # when this handler itself was cancelled.
            for task in tasks:
                task.cancel()
            if tasks:
                await asyncio.gather(*tasks, return_exceptions=True)

            # Only tear down the connection that belongs to this socket.
            await self.disconnect(session_id, websocket=websocket)

    async def _receive_messages(
        self,
        session_id: str,
        connection: Optional[WebSocketConnection] = None,
    ):
        """Read client messages and route them until the connection ends.

        Args:
            session_id: Session the messages belong to.
            connection: Connection to read from; looked up by ``session_id``
                when omitted.

        Returns normally on client disconnect; re-raises any other error.
        """
        if connection is None:
            connection = self.connections.get(session_id)
        if connection is None:
            return

        try:
            while connection.is_active:
                message = await connection.receive_json()

                # The payload is client-controlled; a non-object (list, string,
                # number) would make message.get() raise and drop the connection.
                if not isinstance(message, dict):
                    log_warning(
                        "⚠️ Ignoring non-object message",
                        session_id=session_id,
                        payload_type=type(message).__name__,
                    )
                    continue

                log_debug(
                    "📨 Received message",
                    session_id=session_id,
                    message_type=message.get("type"),
                )

                # Route message based on type
                await self._route_client_message(session_id, message)

        except WebSocketDisconnect:
            log_info("Client disconnected", session_id=session_id)
        except Exception as e:
            log_error(
                "❌ Error receiving messages",
                session_id=session_id,
                error=str(e),
            )
            raise

    async def _send_messages(
        self,
        session_id: str,
        connection: Optional[WebSocketConnection] = None,
        queue: Optional[asyncio.Queue] = None,
    ):
        """Drain the session's outgoing queue to the client.

        A ``{"type": "ping"}`` keep-alive is sent whenever the queue stays idle
        for 30 seconds.

        Args:
            session_id: Session being served.
            connection: Connection to write to; looked up when omitted.
            queue: Outgoing message queue; looked up when omitted.

        Re-raises any send error.
        """
        if connection is None:
            connection = self.connections.get(session_id)
        if queue is None:
            queue = self.message_queues.get(session_id)
        if connection is None or queue is None:
            return

        try:
            while connection.is_active:
                try:
                    message = await asyncio.wait_for(queue.get(), timeout=30.0)
                except asyncio.TimeoutError:
                    # Send ping to keep connection alive
                    await connection.send_json({"type": "ping"})
                    continue

                await connection.send_json(message)
                log_debug(
                    "📤 Sent message",
                    session_id=session_id,
                    message_type=message.get("type"),
                )

        except Exception as e:
            log_error(
                "❌ Error sending messages",
                session_id=session_id,
                error=str(e),
            )
            raise

    async def _route_client_message(self, session_id: str, message: dict):
        """Translate one client message into bus events and/or client replies.

        Handled ``type`` values:

        - ``audio_chunk``
        - ``control`` (dispatched on its ``action`` field)
        - ``ping``

        Unknown types and actions are logged and ignored.
        """
        message_type = message.get("type")

        if message_type == "audio_chunk":
            # Audio data from client
            await self.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=session_id,
                data={
                    "audio_data": message.get("data"),
                    "timestamp": message.get("timestamp"),
                },
            ))

        elif message_type == "control":
            # Control messages
            action = message.get("action")
            config = message.get("config")
            # Treat a missing, null or non-object config as empty so the
            # config.get(...) below cannot raise and kill the receive loop.
            if not isinstance(config, dict):
                config = {}

            if action == "start_conversation":
                # New action: start a conversation for the existing session
                log_info(f"🎤 Starting conversation for session | session_id={session_id}")

                await self.event_bus.publish(Event(
                    type=EventType.CONVERSATION_STARTED,
                    session_id=session_id,
                    data={
                        "config": config,
                        "continuous_listening": config.get("continuous_listening", True),
                    },
                ))

                # Send confirmation to client
                await self.send_message(session_id, {
                    "type": "conversation_started",
                    "message": "Conversation started successfully",
                })

            elif action == "start_session":
                # Deprecated: clients should send start_conversation instead
                log_warning(f"⚠️ Deprecated start_session action received | session_id={session_id}")

                # Still handled, but published as CONVERSATION_STARTED
                await self.event_bus.publish(Event(
                    type=EventType.CONVERSATION_STARTED,
                    session_id=session_id,
                    data=config,
                ))

            elif action == "stop_session":
                await self.event_bus.publish(Event(
                    type=EventType.CONVERSATION_ENDED,
                    session_id=session_id,
                    data={"reason": "user_request"},
                ))

            elif action == "end_session":
                await self.event_bus.publish(Event(
                    type=EventType.SESSION_ENDED,
                    session_id=session_id,
                    data={"reason": "user_request"},
                ))

            elif action == "audio_ended":
                await self.event_bus.publish(Event(
                    type=EventType.AUDIO_PLAYBACK_COMPLETED,
                    session_id=session_id,
                    data={},
                ))

            else:
                log_warning(
                    "⚠️ Unknown control action",
                    session_id=session_id,
                    action=action,
                )

        elif message_type == "ping":
            # Respond to ping
            await self.send_message(session_id, {"type": "pong"})

        else:
            log_warning(
                "⚠️ Unknown message type",
                session_id=session_id,
                message_type=message_type,
            )

    async def send_message(self, session_id: str, message: dict):
        """Queue ``message`` for delivery to the session's client.

        Logs a warning and drops the message if the session has no queue.
        """
        queue = self.message_queues.get(session_id)
        if queue is not None:
            await queue.put(message)
        else:
            log_warning(
                "⚠️ No queue for session",
                session_id=session_id,
            )

    async def broadcast_to_session(self, session_id: str, message: dict):
        """Send ``message`` immediately, bypassing the queue.

        Does nothing when the session has no active connection.
        """
        connection = self.connections.get(session_id)
        if connection is not None and connection.is_active:
            await connection.send_json(message)

    # Event handlers for sending messages to clients

    async def _handle_state_transition(self, event: Event):
        """Forward a state transition (``from_state`` -> ``to_state``) to the client."""
        await self.send_message(event.session_id, {
            "type": "state_change",
            "from": event.data.get("from_state"),
            "to": event.data.get("to_state"),
        })

    async def _handle_stt_ready(self, event: Event):
        """Tell the client that STT is ready to receive audio."""
        await self.send_message(event.session_id, {
            "type": "stt_ready",
            "message": "STT is ready to receive audio",
        })

    async def _handle_stt_result(self, event: Event):
        """Forward an STT result as a ``transcription`` message."""
        await self.send_message(event.session_id, {
            "type": "transcription",
            "text": event.data.get("text", ""),
            "is_final": event.data.get("is_final", False),
            "confidence": event.data.get("confidence", 0.0),
        })

    async def _handle_tts_chunk(self, event: Event):
        """Forward one TTS audio chunk (with index/total/last flags) to the client."""
        await self.send_message(event.session_id, {
            "type": "tts_audio",
            "data": event.data.get("audio_data"),
            "chunk_index": event.data.get("chunk_index"),
            "total_chunks": event.data.get("total_chunks"),
            "is_last": event.data.get("is_last", False),
            "mime_type": event.data.get("mime_type", "audio/mpeg"),
        })

    async def _handle_tts_completed(self, event: Event):
        """Ignore TTS completion; the client relies on ``is_last`` in the chunks."""
        pass

    async def _handle_llm_response(self, event: Event):
        """Forward an LLM response as an ``assistant_response`` message."""
        await self.send_message(event.session_id, {
            "type": "assistant_response",
            "text": event.data.get("text", ""),
            "is_welcome": event.data.get("is_welcome", False),
        })

    async def _handle_error(self, event: Event):
        """Forward a recoverable/critical error event to the client."""
        error_type = event.data.get("error_type", "unknown")
        message = event.data.get("message", "An error occurred")

        await self.send_message(event.session_id, {
            "type": "error",
            "error_type": error_type,
            "message": message,
            "details": event.data.get("details", {}),
        })

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.connections)

    def get_session_connections(self) -> Set[str]:
        """Return the ids of all sessions with a registered connection."""
        return set(self.connections.keys())

    async def close_all_connections(self):
        """Disconnect every registered connection."""
        # Snapshot the keys: disconnect() mutates self.connections.
        session_ids = list(self.connections.keys())
        for session_id in session_ids:
            await self.disconnect(session_id)