File size: 10,536 Bytes
488b360
 
 
9a8789a
488b360
 
2ab3299
100e61a
488b360
100e61a
488b360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100e61a
488b360
 
 
 
 
100e61a
488b360
 
 
 
 
 
 
100e61a
488b360
 
 
 
 
 
 
 
 
 
 
 
 
 
9a8789a
488b360
 
 
9a8789a
 
 
 
488b360
 
 
 
9a8789a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488b360
 
100e61a
488b360
 
 
9a8789a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488b360
 
9a8789a
 
 
488b360
100e61a
488b360
 
 
9a8789a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100e61a
488b360
 
9a8789a
 
 
100e61a
488b360
100e61a
488b360
 
 
9a8789a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100e61a
488b360
 
9a8789a
 
 
100e61a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
from uuid import UUID
import asyncio
from fastapi import WebSocket
from fastapi.websockets import WebSocketDisconnect
from starlette.websockets import WebSocketState
import logging
from typing import Any
from util import ParamsModel

# Session table: user id -> {"websocket": <WebSocket>, "queue": <asyncio.Queue>}.
Connections = dict[
    UUID,
    dict[str, WebSocket | asyncio.Queue],
]


class ServerFullException(Exception):
    """Signal that a new connection was refused because the server is at capacity."""


class ConnectionManager:
    """Registry of active websocket sessions, keyed by user id.

    Each session is a dict holding the accepted ``"websocket"`` and an
    ``asyncio.Queue`` (``"queue"``) of parameter updates for that user.
    The send/receive helpers catch ``Exception``-level errors. They log
    the error and, when the socket is closed or the failure is unexpected,
    remove the session via ``disconnect``. They do not propagate the error.
    """

    # Substrings of RuntimeError messages that mean the socket is already
    # closed or closing. Such errors end the session instead of being
    # logged as unexpected failures.
    _CLOSED_SOCKET_ERRORS = (
        "WebSocket is not connected",
        'Cannot call "send" once a close message has been sent',
        'Cannot call "receive" once a close message has been sent',
        "WebSocket is disconnected",
    )

    def __init__(self):
        self.active_connections: Connections = {}

    async def connect(
        self, user_id: UUID, websocket: WebSocket, max_queue_size: int = 0
    ):
        """Accept *websocket* and register a new session for *user_id*.

        The socket is accepted before the capacity check so that a rejected
        client can still be told why. On success the client is sent the
        ``connected``, ``wait`` and ``send_frame`` status messages, in that
        order.

        Args:
            user_id: Key under which the session is stored.
            websocket: The incoming connection.
            max_queue_size: Maximum number of concurrent sessions. A value
                ``<= 0`` disables the limit.

        Raises:
            ServerFullException: The limit was reached. The client has
                already been sent an error message and the socket is closed.
        """
        await websocket.accept()
        user_count = self.get_user_count()
        logging.info(f"User count: {user_count}")
        if max_queue_size > 0 and user_count >= max_queue_size:
            logging.warning("Server is full")
            await websocket.send_json({"status": "error", "message": "Server is full"})
            await websocket.close()
            raise ServerFullException("Server is full")
        logging.info(f"New user connected: {user_id}")
        # NOTE(review): an existing session for the same user_id is replaced
        # here without closing its websocket. Confirm duplicate ids cannot occur.
        self.active_connections[user_id] = {
            "websocket": websocket,
            "queue": asyncio.Queue(),
        }
        await websocket.send_json(
            {"status": "connected", "message": "Connected"},
        )
        await websocket.send_json({"status": "wait"})
        await websocket.send_json({"status": "send_frame"})

    def check_user(self, user_id: UUID) -> bool:
        """Return ``True`` if *user_id* currently has a registered session."""
        return user_id in self.active_connections

    async def update_data(self, user_id: UUID, new_data: ParamsModel):
        """Enqueue *new_data* for *user_id*. Does nothing if there is no session."""
        user_session = self.active_connections.get(user_id)
        if user_session is None:
            return
        # NOTE(review): older entries are not discarded, so the queue grows
        # whenever updates arrive faster than get_latest_data consumes them.
        await user_session["queue"].put(new_data)

    async def get_latest_data(self, user_id: UUID) -> ParamsModel | None:
        """Wait for and return the next queued update for *user_id*.

        Returns ``None`` immediately if the user has no session. Otherwise
        this waits until an item is available. Items come out in FIFO order
        (oldest first), because ``update_data`` does not drop earlier entries.
        """
        user_session = self.active_connections.get(user_id)
        if user_session is None:
            return None
        # Queue.get() waits for an item rather than raising QueueEmpty, so
        # there is no empty-queue case to handle here.
        # NOTE(review): if the session is deleted while this is waiting,
        # nothing is ever put on the orphaned queue and the await never returns.
        return await user_session["queue"].get()

    def delete_user(self, user_id: UUID):
        """Remove *user_id*'s session, if any, and discard its pending updates."""
        user_session = self.active_connections.pop(user_id, None)
        if user_session is None:
            return
        queue = user_session["queue"]
        # There is no await between empty() and get_nowait(), so get_nowait()
        # cannot find the queue empty here.
        while not queue.empty():
            queue.get_nowait()

    def get_user_count(self) -> int:
        """Return the number of registered sessions."""
        return len(self.active_connections)

    def get_websocket(self, user_id: UUID) -> WebSocket | None:
        """Return *user_id*'s websocket if it is fully open, else ``None``.

        Both the client and the application state must be ``CONNECTED``. A
        socket that either side has begun closing is treated as unusable.
        """
        user_session = self.active_connections.get(user_id)
        if user_session is None:
            return None
        websocket = user_session["websocket"]
        if (
            websocket.client_state == WebSocketState.CONNECTED
            and websocket.application_state == WebSocketState.CONNECTED
        ):
            return websocket
        return None

    async def disconnect(self, user_id: UUID):
        """Close *user_id*'s websocket (best effort) and remove the session.

        Closing is skipped when either side already reports ``DISCONNECTED``.
        Errors while closing are logged, and the session is removed
        regardless. Does nothing if the user has no session.
        """
        # Read the session directly. get_websocket only returns fully-open
        # sockets, but any socket not yet DISCONNECTED should still be closed.
        user_session = self.active_connections.get(user_id)
        if user_session is None:
            return
        if "websocket" in user_session:
            websocket = user_session["websocket"]
            try:
                if (
                    websocket.client_state != WebSocketState.DISCONNECTED
                    and websocket.application_state != WebSocketState.DISCONNECTED
                ):
                    await websocket.close()
            except Exception as e:
                logging.error(f"Error closing websocket for {user_id}: {e}")
        # Always delete the user so a failed close cannot leave a stale entry.
        self.delete_user(user_id)

    async def _handle_runtime_error(
        self, user_id: UUID, error: RuntimeError, method_name: str
    ) -> None:
        """Disconnect on a closed-socket RuntimeError; otherwise just log it."""
        error_msg = str(error)
        if any(marker in error_msg for marker in self._CLOSED_SOCKET_ERRORS):
            logging.info(
                f"WebSocket disconnected/closing for user {user_id}: {error_msg}"
            )
            await self.disconnect(user_id)
        else:
            logging.error(f"Runtime error in {method_name}: {error}")

    async def _handle_disconnect(
        self,
        user_id: UUID,
        disconnect_error: WebSocketDisconnect,
        context: str = "",
    ) -> None:
        """Log a WebSocketDisconnect and remove the session if it is still registered.

        *context* is inserted into the log line (e.g. ``" during send"``).
        """
        code = disconnect_error.code
        if code == 1006:  # ABNORMAL_CLOSURE
            logging.info(
                f"WebSocket abnormally closed for user {user_id}{context}: "
                "Connection was closed without a proper close handshake"
            )
        else:
            logging.info(
                f"WebSocket disconnected for user {user_id} with code {code}"
                f"{context}: {disconnect_error.reason}"
            )
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def _handle_unexpected_error(
        self, user_id: UUID, error: Exception, label: str
    ) -> None:
        """Log an unexpected failure and drop the session to stop repeated errors."""
        logging.error(f"Error: {label}: {error}")
        if user_id in self.active_connections:
            await self.disconnect(user_id)

    async def send_json(self, user_id: UUID, data: dict):
        """Send *data* as JSON to *user_id* if their socket is open.

        Does nothing when the user has no open socket. Failures are logged.
        Closed-socket errors, disconnects and unexpected exceptions also
        remove the session.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                return
            try:
                await websocket.send_json(data)
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "send_json")
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error, " during send")
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Send json")

    async def receive_json(self, user_id: UUID) -> dict | None:
        """Receive one JSON message from *user_id* and return it if it is a dict.

        Returns ``None`` in these cases:
        - the user has no open socket;
        - the payload fails to decode (``ValueError``);
        - the decoded value is not a dict;
        - the socket fails.

        On a socket failure the session may also be removed (see
        ``send_json`` for the same error policy).
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                return None
            try:
                data = await websocket.receive_json()
            except ValueError as json_err:
                # JSON decoding errors (json.JSONDecodeError subclasses ValueError).
                logging.error(f"JSON parsing error for user {user_id}: {json_err}")
                return None
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_json")
                return None
            if not isinstance(data, dict):
                logging.error(
                    f"Expected dict but received {type(data)} from user {user_id}: {data}"
                )
                return None
            return data
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Receive json")
            return None

    async def receive_bytes(self, user_id: UUID) -> bytes | None:
        """Receive one binary message from *user_id*.

        Returns ``None`` when the user has no open socket or the receive
        fails. Error handling matches ``receive_json``.
        """
        try:
            websocket = self.get_websocket(user_id)
            if websocket is None:
                return None
            try:
                return await websocket.receive_bytes()
            except RuntimeError as e:
                await self._handle_runtime_error(user_id, e, "receive_bytes")
                return None
        except WebSocketDisconnect as disconnect_error:
            await self._handle_disconnect(user_id, disconnect_error)
            return None
        except Exception as e:
            await self._handle_unexpected_error(user_id, e, "Receive bytes")
            return None