# flare / websocket_manager.py
# ciyidogan — commit e8a19b3 "Upload 8 files" (verified) — 14.6 kB
"""
WebSocket Manager for Flare
===========================
Manages WebSocket connections and message routing
"""
import asyncio
from typing import Dict, Optional, Set
from fastapi import WebSocket, WebSocketDisconnect
import json
from datetime import datetime
import traceback
from event_bus import EventBus, Event, EventType
from logger import log_info, log_error, log_debug, log_warning
class WebSocketConnection:
    """Wrapper for a WebSocket connection with session metadata.

    Attributes:
        websocket: The underlying FastAPI ``WebSocket``.
        session_id: Identifier of the session this socket belongs to.
        connected_at: Naive UTC timestamp (``datetime.utcnow()``) of creation.
        last_activity: Naive UTC timestamp of the last successful send/receive.
        is_active: Set to ``False`` after a failed send/receive or ``close()``;
            while ``False``, ``send_json`` silently drops messages.
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        self.websocket = websocket
        self.session_id = session_id
        self.connected_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        self.is_active = True

    async def send_json(self, data: dict):
        """Send *data* to the client as JSON.

        Does nothing when the connection is already inactive.

        Raises:
            Exception: Whatever the underlying send raised; the error is logged
                and the connection is marked inactive first.
        """
        try:
            if self.is_active:
                await self.websocket.send_json(data)
                self.last_activity = datetime.utcnow()
        except Exception as e:
            log_error(
                "❌ Failed to send message",
                session_id=self.session_id,
                error=str(e),
            )
            self.is_active = False
            raise

    async def receive_json(self) -> dict:
        """Receive one JSON message from the client.

        Returns:
            The decoded JSON payload. It is not validated here, so callers
            should not assume it is a dict.

        Raises:
            WebSocketDisconnect: The client disconnected (connection marked inactive).
            Exception: Any other receive/decode error; logged, connection marked
                inactive, then re-raised.
        """
        try:
            data = await self.websocket.receive_json()
            self.last_activity = datetime.utcnow()
            return data
        except WebSocketDisconnect:
            self.is_active = False
            raise
        except Exception as e:
            log_error(
                "❌ Failed to receive message",
                session_id=self.session_id,
                error=str(e),
            )
            self.is_active = False
            raise

    async def close(self):
        """Mark the connection inactive and close the socket, best-effort.

        Errors raised while closing (for example if the socket was already
        closed) are logged at debug level and suppressed. Unlike the previous
        bare ``except:``, ``asyncio.CancelledError`` and ``KeyboardInterrupt``
        are no longer swallowed, so task cancellation keeps working.
        """
        self.is_active = False
        try:
            await self.websocket.close()
        except Exception as e:
            log_debug(
                "WebSocket close failed (ignored)",
                session_id=self.session_id,
                error=str(e),
            )
class WebSocketManager:
    """Manages WebSocket connections and routes messages in both directions.

    * Client -> server: a receiver task reads JSON messages from the socket and
      republishes them as ``EventBus`` events (see ``_route_client_message``).
    * Server -> client: subscribed ``EventBus`` events are converted into client
      messages and put on a per-session ``asyncio.Queue``; a sender task drains
      that queue onto the socket.

    At most one connection is registered per ``session_id``; connecting again
    with the same id closes the previous connection first.
    """

    # Seconds the sender waits for an outbound message before sending a
    # {"type": "ping"} keep-alive to the client.
    PING_INTERVAL_SECONDS = 30.0

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> connection wrapper currently registered for that session
        self.connections: Dict[str, WebSocketConnection] = {}
        # session_id -> queue of outbound messages waiting to be sent
        self.message_queues: Dict[str, asyncio.Queue] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to the EventBus events that must be forwarded to clients."""
        # State events
        self.event_bus.subscribe(EventType.STATE_TRANSITION, self._handle_state_transition)
        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
        # TTS events
        self.event_bus.subscribe(EventType.TTS_CHUNK_READY, self._handle_tts_chunk)
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response)
        # Error events
        self.event_bus.subscribe(EventType.RECOVERABLE_ERROR, self._handle_error)
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_error)

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept *websocket* and register it as the connection for *session_id*.

        Any connection already registered for the same session is disconnected
        first. Publishes ``WEBSOCKET_CONNECTED`` after registration.
        """
        await websocket.accept()

        if session_id in self.connections:
            log_warning(
                "⚠️ Existing connection for session, closing old one",
                session_id=session_id,
            )
            await self.disconnect(session_id)

        # No await between these two assignments, so a connection and its
        # outbound queue are always registered together.
        self.connections[session_id] = WebSocketConnection(websocket, session_id)
        self.message_queues[session_id] = asyncio.Queue()

        log_info(
            "✅ WebSocket connected",
            session_id=session_id,
            total_connections=len(self.connections),
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_CONNECTED,
            session_id=session_id,
            data={},
        ))

    async def disconnect(self, session_id: str, websocket: Optional[WebSocket] = None):
        """Unregister and close the connection for *session_id*.

        Args:
            session_id: Session whose connection should be torn down.
            websocket: Optional identity guard. When given, the session is only
                disconnected if its registered connection wraps this exact
                socket. This keeps a handler whose connection was replaced by a
                reconnect from tearing down the newer connection.

        Does nothing (no log, no event) when no matching connection is
        registered, so repeated calls do not publish duplicate
        ``WEBSOCKET_DISCONNECTED`` events.
        """
        connection = self.connections.get(session_id)
        if connection is None:
            return
        if websocket is not None and connection.websocket is not websocket:
            return

        # Unregister before the first await: a concurrent caller then finds
        # nothing to remove instead of racing on ``del`` (KeyError).
        del self.connections[session_id]
        self.message_queues.pop(session_id, None)

        await connection.close()

        log_info(
            "🔌 WebSocket disconnected",
            session_id=session_id,
            total_connections=len(self.connections),
        )

        await self.event_bus.publish(Event(
            type=EventType.WEBSOCKET_DISCONNECTED,
            session_id=session_id,
            data={},
        ))

    async def handle_connection(self, websocket: WebSocket, session_id: str):
        """Run the whole lifecycle of one WebSocket connection.

        Registers the connection, then runs a receiver task and a sender task
        until either one finishes. ``WebSocketDisconnect`` counts as a normal
        close. Any other exception, whether from ``connect`` or from either
        task, is logged and published as ``WEBSOCKET_ERROR``. On exit the
        connection is disconnected, unless a newer connection has since taken
        over the session.
        """
        try:
            await self.connect(websocket, session_id)

            receive_task = asyncio.create_task(self._receive_messages(session_id))
            send_task = asyncio.create_task(self._send_messages(session_id))
            tasks = [receive_task, send_task]
            try:
                done, _pending = await asyncio.wait(
                    tasks,
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                # Runs even if this handler itself is cancelled, so the child
                # tasks never outlive it.
                await self._cancel_tasks(tasks)

            # asyncio.wait does not re-raise task exceptions; surface them so
            # they reach the handlers below instead of being silently lost.
            for task in tasks:
                if task in done and not task.cancelled():
                    task.result()
        except WebSocketDisconnect:
            log_info("WebSocket disconnected normally", session_id=session_id)
        except Exception as e:
            log_error(
                "❌ WebSocket error",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc(),
            )
            await self.event_bus.publish(Event(
                type=EventType.WEBSOCKET_ERROR,
                session_id=session_id,
                data={
                    "error_type": "websocket_error",
                    "message": str(e),
                },
            ))
        finally:
            await self.disconnect(session_id, websocket)

    @staticmethod
    async def _cancel_tasks(tasks: list):
        """Cancel every unfinished task in *tasks* and wait until all have settled."""
        for task in tasks:
            if not task.done():
                task.cancel()
        # return_exceptions=True: outcomes of finished tasks are inspected by
        # the caller; cancelled tasks' CancelledError is expected here.
        await asyncio.gather(*tasks, return_exceptions=True)

    async def _receive_messages(self, session_id: str):
        """Read client messages until the connection closes and route each one.

        A client disconnect ends the loop normally. Any other error is logged
        and re-raised. Payloads that are not JSON objects are logged and skipped.
        """
        connection = self.connections.get(session_id)
        if connection is None:
            return

        try:
            while connection.is_active:
                message = await connection.receive_json()

                # The protocol expects JSON objects with a "type" key; any
                # other JSON value (list, number, ...) would crash on ``.get``.
                if not isinstance(message, dict):
                    log_warning(
                        "⚠️ Ignoring non-object message",
                        session_id=session_id,
                        payload_type=type(message).__name__,
                    )
                    continue

                log_debug(
                    "📨 Received message",
                    session_id=session_id,
                    message_type=message.get("type"),
                )
                await self._route_client_message(session_id, message)
        except WebSocketDisconnect:
            log_info("Client disconnected", session_id=session_id)
        except Exception as e:
            log_error(
                "❌ Error receiving messages",
                session_id=session_id,
                error=str(e),
            )
            raise

    async def _send_messages(self, session_id: str):
        """Drain the session's outbound queue onto the socket.

        If no message arrives within ``PING_INTERVAL_SECONDS``, a
        ``{"type": "ping"}`` keep-alive is sent instead. Send errors are logged
        and re-raised.
        """
        connection = self.connections.get(session_id)
        queue = self.message_queues.get(session_id)
        if connection is None or queue is None:
            return

        try:
            while connection.is_active:
                try:
                    message = await asyncio.wait_for(
                        queue.get(), timeout=self.PING_INTERVAL_SECONDS
                    )
                except asyncio.TimeoutError:
                    # Idle: keep the connection alive.
                    await connection.send_json({"type": "ping"})
                    continue

                await connection.send_json(message)
                log_debug(
                    "📤 Sent message",
                    session_id=session_id,
                    message_type=message.get("type"),
                )
        except Exception as e:
            log_error(
                "❌ Error sending messages",
                session_id=session_id,
                error=str(e),
            )
            raise

    async def _route_client_message(self, session_id: str, message: dict):
        """Translate one client message into the matching EventBus event.

        Handled ``type`` values:

        * ``"audio_chunk"`` publishes ``AUDIO_CHUNK_RECEIVED``.
        * ``"control"`` maps the ``action`` field: ``start_session`` /
          ``end_session`` / ``audio_ended``.
        * ``"ping"`` is answered with a queued ``{"type": "pong"}``.

        Unknown types and unknown control actions are logged and ignored.
        """
        message_type = message.get("type")

        if message_type == "audio_chunk":
            await self.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=session_id,
                data={
                    "audio_data": message.get("data"),
                    "timestamp": message.get("timestamp"),
                },
            ))
        elif message_type == "control":
            action = message.get("action")
            if action == "start_session":
                await self.event_bus.publish(Event(
                    type=EventType.SESSION_STARTED,
                    session_id=session_id,
                    data=message.get("config", {}),
                ))
            elif action == "end_session":
                await self.event_bus.publish(Event(
                    type=EventType.SESSION_ENDED,
                    session_id=session_id,
                    data={"reason": "user_request"},
                ))
            elif action == "audio_ended":
                await self.event_bus.publish(Event(
                    type=EventType.AUDIO_PLAYBACK_COMPLETED,
                    session_id=session_id,
                    data={},
                ))
            else:
                log_warning(
                    "⚠️ Unknown control action",
                    session_id=session_id,
                    action=action,
                )
        elif message_type == "ping":
            await self.send_message(session_id, {"type": "pong"})
        else:
            log_warning(
                "⚠️ Unknown message type",
                session_id=session_id,
                message_type=message_type,
            )

    async def send_message(self, session_id: str, message: dict):
        """Queue *message* for delivery to the client of *session_id*.

        Logs a warning and drops the message if the session has no queue.
        """
        queue = self.message_queues.get(session_id)
        if queue:
            await queue.put(message)
        else:
            log_warning(
                "⚠️ No queue for session",
                session_id=session_id,
            )

    async def broadcast_to_session(self, session_id: str, message: dict):
        """Send *message* immediately, bypassing the session's queue."""
        connection = self.connections.get(session_id)
        if connection and connection.is_active:
            await connection.send_json(message)

    # Event handlers for sending messages to clients

    async def _handle_state_transition(self, event: Event):
        """Send a state transition to the client."""
        await self.send_message(event.session_id, {
            "type": "state_change",
            "from": event.data.get("from_state"),
            "to": event.data.get("to_state"),
        })

    async def _handle_stt_ready(self, event: Event):
        """Tell the client that STT can receive audio."""
        await self.send_message(event.session_id, {
            "type": "stt_ready",
            "message": "STT is ready to receive audio",
        })

    async def _handle_stt_result(self, event: Event):
        """Send an STT transcription result to the client."""
        await self.send_message(event.session_id, {
            "type": "transcription",
            "text": event.data.get("text", ""),
            "is_final": event.data.get("is_final", False),
            "confidence": event.data.get("confidence", 0.0),
        })

    async def _handle_tts_chunk(self, event: Event):
        """Send one TTS audio chunk to the client."""
        await self.send_message(event.session_id, {
            "type": "tts_audio",
            "data": event.data.get("audio_data"),
            "chunk_index": event.data.get("chunk_index"),
            "total_chunks": event.data.get("total_chunks"),
            "is_last": event.data.get("is_last", False),
            "mime_type": event.data.get("mime_type", "audio/mpeg"),
        })

    async def _handle_tts_completed(self, event: Event):
        """Intentionally a no-op.

        The client already learns that TTS finished from the ``is_last`` flag
        on the final ``tts_audio`` chunk.
        """

    async def _handle_llm_response(self, event: Event):
        """Send an assistant (LLM) response to the client."""
        await self.send_message(event.session_id, {
            "type": "assistant_response",
            "text": event.data.get("text", ""),
            "is_welcome": event.data.get("is_welcome", False),
        })

    async def _handle_error(self, event: Event):
        """Forward a recoverable or critical error event to the client."""
        error_type = event.data.get("error_type", "unknown")
        message = event.data.get("message", "An error occurred")
        await self.send_message(event.session_id, {
            "type": "error",
            "error_type": error_type,
            "message": message,
            "details": event.data.get("details", {}),
        })

    def get_connection_count(self) -> int:
        """Return the number of registered connections."""
        return len(self.connections)

    def get_session_connections(self) -> Set[str]:
        """Return the session ids that currently have a registered connection."""
        return set(self.connections.keys())

    async def close_all_connections(self):
        """Disconnect every registered connection."""
        # Snapshot the keys: disconnect() mutates self.connections.
        for session_id in list(self.connections.keys()):
            await self.disconnect(session_id)