import asyncio
import json
import logging
from typing import Any, Dict, Optional, Tuple

from fastapi import HTTPException, WebSocket, status

logger = logging.getLogger(__name__)


class InRequest:
    """Pending response futures for a single socket, keyed by request ID."""

    def __init__(self):
        # request_id -> Future that ConnectionManager.notify() resolves.
        self.responses: Dict[str, asyncio.Future] = {}


class ConnectionManager:
    """Track WebSocket connections and correlate requests with responses.

    ``listen`` registers a future for a ``(socket_id, request_id)`` pair and
    waits on it. ``notify`` parses an incoming JSON message carrying
    ``request_id`` and ``payload`` keys and resolves the matching future with
    the payload.
    """

    def __init__(self):
        # Socket ID of the first client to connect; cleared when it disconnects.
        self.available: Optional[str] = None
        # Maps socket ID to WebSocket connection.
        self.active_connections: Dict[str, WebSocket] = {}
        # Maps socket ID to the response futures still awaited on it.
        self.in_request: Dict[str, InRequest] = {}

    async def connect(self, socket_id: str, websocket: WebSocket) -> str:
        """Accept *websocket*, register it under *socket_id* and return the ID.

        If no socket is currently ``available``, this one becomes available.
        """
        await websocket.accept()
        self.active_connections[socket_id] = websocket
        if self.available is None:
            self.available = socket_id
        return socket_id

    def disconnect(self, socket_id: str) -> None:
        """Forget *socket_id* and cancel every response still awaited on it.

        Cancelling the pending futures makes blocked ``listen`` calls fail
        with a 502 instead of waiting forever for a reply that cannot arrive.
        """
        self.active_connections.pop(socket_id, None)
        if self.available == socket_id:
            self.available = None
        pending = self.in_request.pop(socket_id, None)
        if pending is not None:
            for future in pending.responses.values():
                if not future.done():
                    future.cancel()

    async def broadcast(self, message: str) -> None:
        """Send *message* as text to every connected socket."""
        # Iterate over a snapshot: connect()/disconnect() may mutate the dict
        # while a send is being awaited.
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)

    def _require_socket(self, socket_id: str, detail: str) -> WebSocket:
        """Return the WebSocket for *socket_id* or raise a 502 carrying *detail*."""
        websocket = self.active_connections.get(socket_id)
        if websocket is None:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY, detail=detail)
        return websocket

    async def receive_text(self, socket_id: str) -> str:
        """Receive one text frame from *socket_id*.

        Raises:
            HTTPException: 502 if *socket_id* is not connected.
        """
        websocket = self._require_socket(
            socket_id, f"Socket ID {socket_id} not connected")
        return await websocket.receive_text()

    async def send_text(self, socket_id: str, message: str) -> None:
        """Send *message* as a text frame to *socket_id*.

        Raises:
            HTTPException: 502 if *socket_id* is not connected.
        """
        websocket = self._require_socket(
            socket_id, "WebSocket connection not found.")
        await websocket.send_text(message)

    async def send_bytes(self, socket_id: str, binary_data: bytes) -> None:
        """Send *binary_data* as a binary frame to *socket_id*.

        Raises:
            HTTPException: 502 if *socket_id* is not connected.
        """
        websocket = self._require_socket(
            socket_id, f"Socket ID {socket_id} not connected")
        await websocket.send_bytes(binary_data)

    async def listen(self, socket_id: str, request_id: str) -> Any:
        """Wait for the response to *request_id* on *socket_id*; return its payload.

        Several requests may be awaited on the same socket at once. Each one
        gets its own future inside that socket's ``InRequest``.

        Raises:
            HTTPException: 502 if the wait is cancelled (for example because
                ``disconnect`` was called for *socket_id*).
        """
        # Reuse the socket's registry so concurrent requests don't clobber
        # each other's futures.
        pending = self.in_request.setdefault(socket_id, InRequest())
        future = asyncio.get_running_loop().create_future()
        pending.responses[request_id] = future
        try:
            return await future
        except asyncio.CancelledError as exc:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected or canceled") from exc
        finally:
            self._discard_future(socket_id, request_id, future)

    def _discard_future(self, socket_id: str, request_id: str,
                        future: asyncio.Future) -> None:
        """Unregister *future*, dropping the socket's registry once it is empty."""
        pending = self.in_request.get(socket_id)
        # Identity check: after a disconnect/reconnect the slot may already hold
        # a newer future for the same request ID, which must not be removed.
        if pending is None or pending.responses.get(request_id) is not future:
            return
        del pending.responses[request_id]
        if not pending.responses:
            del self.in_request[socket_id]

    async def notify(self, socket_id: str, message: str) -> None:
        """Resolve the pending ``listen`` future that *message* answers.

        Messages that cannot be parsed, carry no ``request_id``, or match no
        pending request are dropped without raising.
        """
        logger.debug("notify %s: %s", socket_id, message)
        pending = self.in_request.get(socket_id)
        if pending is None:
            return
        request_id, payload = self.extract_message(message)
        if request_id is None:
            return
        future = pending.responses.get(request_id)
        if future is None or future.done():
            # Stale, duplicate or unknown reply; set_result would raise here.
            logger.warning(
                "No pending request %r on socket %s", request_id, socket_id)
            return
        # listen()'s finally-block removes the entry once the waiter resumes.
        future.set_result(payload)

    def extract_message(self, message: str) -> Tuple[Optional[Any], Any]:
        """Parse *message* as JSON and return its ``(request_id, payload)``.

        Returns ``(None, None)`` if *message* is not valid JSON or does not
        decode to a JSON object. A missing key yields ``None`` in its slot.
        """
        try:
            obj = json.loads(message)
        except (TypeError, ValueError) as exc:  # JSONDecodeError is a ValueError
            logger.warning("extract_message error: %s", exc)
            return None, None
        if not isinstance(obj, dict):
            if obj is not None:
                logger.warning(
                    "extract_message error: expected a JSON object, got %s",
                    type(obj).__name__)
            return None, None
        return obj.get("request_id"), obj.get("payload")