# Source snapshot at commit 9a8789a (author: radames): "refactor: fix websocket errors"
from uuid import UUID
import asyncio
from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
from starlette.websockets import WebSocketState
import logging
from typing import Any
from util import ParamsModel
# One user's session record: the accepted socket lives under "websocket" and
# that user's pending-parameter queue under "queue".
_Session = dict[str, WebSocket | asyncio.Queue]
Connections = dict[UUID, _Session]
class ServerFullException(Exception):
    """Signals that no further users can be admitted because the server is at capacity."""
class ConnectionManager:
    """Track per-user WebSocket sessions and their pending-parameter queues.

    ``active_connections`` maps a user id to a session dict holding the
    accepted ``WebSocket`` under ``"websocket"`` and an ``asyncio.Queue`` of
    submitted parameters under ``"queue"``.
    """

    # Substrings of RuntimeError messages that indicate the socket has
    # already been disconnected or is closing; such errors end the session
    # instead of just being logged.
    _CLOSED_SOCKET_ERRORS = (
        "WebSocket is not connected",
        'Cannot call "send" once a close message has been sent',
        'Cannot call "receive" once a close message has been sent',
        "WebSocket is disconnected",
    )

    def __init__(self):
        """Start with no registered sessions."""
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept *websocket* and register it as the session for *user_id*.

        Args:
            user_id: Identifier for the new session.
            websocket: The incoming connection; it is accepted here.
            max_queue_size: Maximum number of registered sessions. A value of
                ``0`` (or any non-positive value) disables the limit.

        Raises:
            ServerFullException: If the limit has been reached. The client is
                sent an error message and the socket is closed first.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        print(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            print("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")
        print(f"New user connected: {user_id}")
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return True if *user_id* currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: ParamsModel):
        """Enqueue *new_data* for *user_id*; ignored if the user has no session."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            queue = user_session["queue"]
            await queue.put(new_data)

    async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
        """Wait for and return the next queued parameters for *user_id*.

        Returns ``None`` immediately if *user_id* has no session. Otherwise it
        awaits ``queue.get()``: it blocks until an item is enqueued (e.g. by
        ``update_data``) and yields items in FIFO order.
        """
        user_session = self.active_connections.get(user_id)
        if not user_session:
            return None
        # NOTE(review): ``await queue.get()`` never raises QueueEmpty (only
        # ``get_nowait`` does). A caller waiting here is therefore not woken
        # if the session is deleted meanwhile.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove *user_id*'s session (if any) and discard its queued items.

        The websocket is not closed here; ``disconnect`` does that.
        """
        user_session = self.active_connections.pop(user_id, None)
        if user_session:
            queue = user_session["queue"]
            while not queue.empty():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    break

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> WebSocket | None:
        """Return *user_id*'s websocket if it is fully connected, else None."""
        user_session = self.active_connections.get(user_id)
        if user_session:
            websocket = user_session["websocket"]
            # Both sides must be CONNECTED: a socket the client dropped or
            # the application has started closing is not usable.
            if (
                websocket.client_state == WebSocketState.CONNECTED
                and websocket.application_state == WebSocketState.CONNECTED
            ):
                return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close *user_id*'s websocket if still open, then drop the session.

        Unknown users are ignored. Errors raised while closing are logged,
        and the session is removed regardless.
        """
        # Read the session directly rather than via get_websocket, which
        # would hide sockets that are already half-closed.
        user_session = self.active_connections.get(user_id)
        if user_session is None:
            return
        if "websocket" in user_session:
            websocket = user_session["websocket"]
            try:
                # Only attempt close if neither side is already disconnected.
                if (
                    websocket.client_state != WebSocketState.DISCONNECTED
                    and websocket.application_state != WebSocketState.DISCONNECTED
                ):
                    await websocket.close()
            except Exception as e:
                logging.error(f"Error closing websocket for {user_id}: {e}")
        # Always delete the user to ensure cleanup.
        self.delete_user(user_id)

    def _is_closed_socket_error(self, error: RuntimeError) -> bool:
        """Return True if *error*'s message matches a closed-socket marker."""
        message = str(error)
        return any(marker in message for marker in self._CLOSED_SOCKET_ERRORS)

    async def _drop_stale_session(self, user_id: UUID):
        """Tear down a registered session whose websocket is no longer connected."""
        # get_websocket returned None: either no session exists, or one does
        # but its socket is not CONNECTED. A dead session would otherwise
        # linger and count toward connect()'s max_queue_size.
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def _handle_runtime_error(
        self, user_id: UUID, error: RuntimeError, operation: str
    ):
        """Disconnect on closed-socket RuntimeErrors; log any other RuntimeError."""
        error_msg = str(error)
        if self._is_closed_socket_error(error):
            logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
            await self.disconnect(user_id)
        else:
            logging.error(f"Runtime error in {operation}: {error}")

    async def _handle_disconnect(
        self, user_id: UUID, error: WebSocketDisconnect, context: str = ""
    ):
        """Log a WebSocketDisconnect and drop the session.

        *context* is inserted after the user id in the log message
        (e.g. ``" during send"``).
        """
        code = error.code
        if code == 1006:  # ABNORMAL_CLOSURE
            logging.info(
                f"WebSocket abnormally closed for user {user_id}{context}: "
                "Connection was closed without a proper close handshake"
            )
        else:
            logging.info(
                f"WebSocket disconnected for user {user_id} with code {code}{context}: "
                f"{error.reason}"
            )
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def _handle_unexpected_error(
        self, user_id: UUID, label: str, error: Exception
    ):
        """Log an unexpected failure and drop the session to stop further errors."""
        logging.error(f"Error: {label}: {error}")
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def send_json(self, user_id: UUID, data: dict):
        """Send *data* as JSON to *user_id*; on connection failure, drop the session.

        Does nothing (beyond cleaning up a dead session) if the user has no
        connected websocket. Send failures are logged, not raised.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                await self._drop_stale_session(user_id)
                return
            try:
                await websocket.send_json(data)
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "send_json")
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error, " during send")
        except Exception as e:
            await self._handle_unexpected_error(user_id, "Send json", e)

    async def receive_json(self, user_id: UUID) -> dict | None:
        """Receive one JSON message from *user_id* and return it if it is a dict.

        Returns None in these cases:
        - the user has no connected websocket;
        - the payload is not valid JSON (the session is kept);
        - the decoded value is not a dict (the session is kept);
        - the connection failed (the session is dropped).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                await self._drop_stale_session(user_id)
                return None
            try:
                data = await websocket.receive_json()
            except ValueError as json_err:
                # Malformed JSON is reported but does not end the session.
                logging.error(f"JSON parsing error for user {user_id}: {json_err}")
                return None
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_json")
                return None
            if not isinstance(data, dict):
                logging.error(f"Expected dict but received {type(data)} from user {user_id}: {data}")
                return None
            return data
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, "Receive json", e)
            return None

    async def receive_bytes(self, user_id: UUID) -> bytes | None:
        """Receive one binary message from *user_id*.

        Returns None if the user has no connected websocket or the
        connection failed (in which case the session is dropped).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                await self._drop_stale_session(user_id)
                return None
            try:
                return await websocket.receive_bytes()
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_bytes")
                return None
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, "Receive bytes", e)
            return None