|
from uuid import UUID |
|
import asyncio |
|
from fastapi import WebSocket |
|
from fastapi.websockets import WebSocketDisconnect |
|
from starlette.websockets import WebSocketState |
|
import logging |
|
from typing import Any |
|
from util import ParamsModel |
|
|
|
# Registry type: maps a user's UUID to that user's session dict. A session
# dict is keyed by string and holds either a WebSocket or an asyncio.Queue.
Connections = dict[
    UUID,
    dict[str, WebSocket | asyncio.Queue],
]
|
|
|
|
|
class ServerFullException(Exception):
    """Raised when the server is full and cannot take on another user."""
|
|
|
|
|
class ConnectionManager:
    """Registry of per-user WebSocket sessions with defensive I/O helpers.

    Every session lives in ``active_connections`` under the user's UUID as a
    dict with two entries: ``"websocket"`` (the accepted socket) and
    ``"queue"`` (an ``asyncio.Queue`` of pending parameter updates).

    ``send_json``, ``receive_json`` and ``receive_bytes`` catch
    ``RuntimeError``, ``WebSocketDisconnect`` and any other ``Exception``
    raised while talking to the socket: they log, drop the session where the
    connection is gone, and return ``None`` instead of raising.
    """

    # Fragments of RuntimeError messages that are treated as "the socket is
    # already closed or closing"; any other RuntimeError is only logged.
    _CLOSED_SOCKET_ERRORS = (
        "WebSocket is not connected",
        'Cannot call "send" once a close message has been sent',
        'Cannot call "receive" once a close message has been sent',
        "WebSocket is disconnected",
    )

    def __init__(self):
        """Create a manager with no registered sessions."""
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept ``websocket`` and register a new session for ``user_id``.

        Args:
            user_id: Key under which the session is stored.
            websocket: Incoming socket; it is accepted before anything else.
            max_queue_size: Maximum number of simultaneously registered
                users. ``0`` (the default) or a negative value disables the
                limit.

        Raises:
            ServerFullException: If the limit is enabled and already reached.
                The client is sent an error payload and the socket is closed
                before the exception is raised.
        """
        # Accept first so that a "server full" error can still be delivered
        # to the client as a JSON message.
        await websocket.accept()
        user_count = self.get_user_count()
        logging.info(f"User count: {user_count}")

        if max_queue_size > 0 and user_count >= max_queue_size:
            logging.warning("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")

        logging.info(f"New user connected: {user_id}")
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }

        # Initial handshake messages, sent in this fixed order.
        await websocket.send_json({"status": "connected", "message": "Connected"})
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return ``True`` if ``user_id`` currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: ParamsModel):
        """Enqueue ``new_data`` for ``user_id``; do nothing for unknown users."""
        session = self.active_connections.get(user_id)
        if session:
            await session["queue"].put(new_data)

    async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
        """Return the oldest queued item for ``user_id``.

        Waits until an item is available when the user's queue is empty
        (``asyncio.Queue.get`` suspends rather than raising ``QueueEmpty``).
        Returns ``None`` immediately if the user has no session.
        """
        session = self.active_connections.get(user_id)
        if not session:
            return None
        return await session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove the session of ``user_id`` (if any) and drain its queue.

        The socket itself is not closed here; see ``disconnect`` for that.
        """
        session = self.active_connections.pop(user_id, None)
        if not session:
            return
        queue = session["queue"]
        # NOTE(review): draining does not wake a coroutine that is already
        # suspended in ``get_latest_data`` for this user - confirm callers
        # cancel such tasks themselves.
        try:
            while True:
                queue.get_nowait()
        except asyncio.QueueEmpty:
            pass

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> WebSocket | None:
        """Return the user's socket only if both ends report ``CONNECTED``.

        Returns ``None`` when the user is unknown or when either
        ``client_state`` or ``application_state`` is not
        ``WebSocketState.CONNECTED``. The session is left registered in the
        latter case.
        """
        session = self.active_connections.get(user_id)
        if not session:
            return None
        websocket = session["websocket"]
        both_connected = (
            websocket.client_state == WebSocketState.CONNECTED
            and websocket.application_state == WebSocketState.CONNECTED
        )
        return websocket if both_connected else None

    async def disconnect(self, user_id: UUID):
        """Close the user's socket (best effort) and remove the session.

        Unknown users are ignored. Errors raised while closing are logged and
        swallowed so that the session is always removed afterwards.
        """
        if user_id not in self.active_connections:
            return

        session = self.active_connections[user_id]
        if session and "websocket" in session:
            websocket = session["websocket"]
            try:
                # Only send a close frame if neither side has already
                # finished the connection.
                if (
                    websocket.client_state != WebSocketState.DISCONNECTED
                    and websocket.application_state != WebSocketState.DISCONNECTED
                ):
                    await websocket.close()
            except Exception as error:
                logging.error(f"Error closing websocket for {user_id}: {error}")

        self.delete_user(user_id)

    async def _on_runtime_error(
        self, user_id: UUID, error: RuntimeError, operation_name: str
    ) -> None:
        """Drop the session if ``error`` means the socket is closed, else log it."""
        error_msg = str(error)
        if any(fragment in error_msg for fragment in self._CLOSED_SOCKET_ERRORS):
            logging.info(f"WebSocket disconnected/closing for user {user_id}: {error_msg}")
            await self.disconnect(user_id)
        else:
            logging.error(f"Runtime error in {operation_name}: {error}")

    async def _on_disconnect(
        self, user_id: UUID, disconnect_error: WebSocketDisconnect, phase: str
    ) -> None:
        """Log a ``WebSocketDisconnect`` and remove the user's session.

        ``phase`` is inserted verbatim into the log line (``" during send"``
        for sends, empty for receives).
        """
        code = disconnect_error.code
        if code == 1006:
            logging.info(f"WebSocket abnormally closed for user {user_id}{phase}: Connection was closed without a proper close handshake")
        else:
            logging.info(f"WebSocket disconnected for user {user_id} with code {code}{phase}: {disconnect_error.reason}")
        # ``disconnect`` itself ignores users that are no longer registered.
        await self.disconnect(user_id)

    async def _guarded_io(
        self, user_id: UUID, operation, *, operation_name: str, error_label: str, phase: str = ""
    ):
        """Run ``operation(websocket)`` with the shared error handling.

        Args:
            user_id: User whose socket is used.
            operation: Callable that takes the connected socket and returns
                an awaitable.
            operation_name: Name used in the "Runtime error in ..." log line.
            error_label: Label used in the generic "Error: ..." log line.
            phase: Extra text for disconnect log lines.

        Returns:
            The awaited result of ``operation``, or ``None`` when the user
            has no connected socket or any handled error occurred.
        """
        try:
            websocket = self.get_websocket(user_id)
            if not websocket:
                return None
            try:
                return await operation(websocket)
            except RuntimeError as error:
                await self._on_runtime_error(user_id, error, operation_name)
                return None
        except WebSocketDisconnect as disconnect_error:
            await self._on_disconnect(user_id, disconnect_error, phase)
            return None
        except Exception as error:
            logging.error(f"Error: {error_label}: {error}")
            await self.disconnect(user_id)
            return None

    async def send_json(self, user_id: UUID, data: dict):
        """Send ``data`` as JSON to ``user_id``; silently skip if not connected.

        Connection errors are logged and lead to the session being removed;
        nothing is raised to the caller for ``Exception``-derived errors.
        """
        await self._guarded_io(
            user_id,
            lambda websocket: websocket.send_json(data),
            operation_name="send_json",
            error_label="Send json",
            phase=" during send",
        )

    async def receive_json(self, user_id: UUID) -> dict | None:
        """Receive one JSON message from ``user_id`` and return it as a dict.

        Returns ``None`` when the user has no connected socket, when the
        payload fails to parse (``ValueError``), when the parsed payload is
        not a ``dict``, or when a connection error was handled.
        """

        async def read_dict(websocket):
            try:
                data = await websocket.receive_json()
            except ValueError as json_err:
                logging.error(f"JSON parsing error for user {user_id}: {json_err}")
                return None
            if not isinstance(data, dict):
                logging.error(f"Expected dict but received {type(data)} from user {user_id}: {data}")
                return None
            return data

        return await self._guarded_io(
            user_id,
            read_dict,
            operation_name="receive_json",
            error_label="Receive json",
        )

    async def receive_bytes(self, user_id: UUID) -> bytes | None:
        """Receive one binary message from ``user_id``.

        Returns ``None`` when the user has no connected socket or when a
        connection error was handled.
        """
        return await self._guarded_io(
            user_id,
            lambda websocket: websocket.receive_bytes(),
            operation_name="receive_bytes",
            error_label="Receive bytes",
        )
|
|