File size: 4,652 Bytes
d1eebad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import socketio
import asyncio


from apps.webui.models.users import Users
from utils.utils import decode_token

# Socket.IO server running under ASGI. Per the python-socketio docs,
# cors_allowed_origins=[] disables the library's own CORS handling
# (NOTE(review): confirm origin checks are enforced by the outer app).
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
# ASGI application that serves the Socket.IO endpoint at /ws/socket.io.
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# In-memory bookkeeping. These are plain module-level dicts, so each worker
# process has its own copy.

# session id -> user id (filled by connect / user-join, cleared by disconnect)
SESSION_POOL = {}
# user id -> list of that user's session ids
USER_POOL = {}
# model id -> dict that includes "sids" (sessions that reported using the
# model) and "callback" (the pending remove_after_timeout task)
USAGE_POOL = {}
# Delay in seconds that remove_after_timeout waits before dropping usage
TIMEOUT_DURATION = 3


@sio.event
async def connect(sid, environ, auth):
    """Socket.IO ``connect`` handler: register authenticated sessions.

    If ``auth`` carries a ``"token"`` that ``decode_token`` resolves to a
    payload with an ``"id"`` of an existing user, the session is recorded in
    SESSION_POOL / USER_POOL. The new user count and the models in use are then
    broadcast to every client. The handler never returns ``False``, so
    python-socketio accepts the connection even when no user is resolved.

    Args:
        sid: Socket.IO session id of the connecting client.
        environ: Request environment supplied by python-socketio (unused).
        auth: Client-supplied auth payload, or ``None``.
    """
    user = None
    if auth and "token" in auth:
        data = decode_token(auth["token"])

        if data is not None and "id" in data:
            user = Users.get_user_by_id(data["id"])

        if user:
            SESSION_POOL[sid] = user.id
            sids = USER_POOL.setdefault(user.id, [])
            # The same session may also send "user-join"; register it only once
            # so a single disconnect fully removes it.
            if sid not in sids:
                sids.append(sid)

            print(f"user {user.name}({user.id}) connected with session ID {sid}")

            # USER_POOL keys are already unique, so len() is the user count.
            await sio.emit("user-count", {"count": len(USER_POOL)})
            await sio.emit("usage", {"models": get_models_in_use()})
            await sio.emit("usage", {"models": get_models_in_use()})


@sio.on("user-join")
async def user_join(sid, data):
    """Handle ``user-join``: (re)associate a session with an authenticated user.

    Expects ``data`` to be a dict with an optional ``"auth"`` dict containing a
    ``"token"``. When the token decodes to a payload with an ``"id"`` of an
    existing user, the session is recorded in SESSION_POOL / USER_POOL and the
    user count is broadcast. Malformed payloads are ignored.

    Args:
        sid: Socket.IO session id of the sender.
        data: Event payload sent by the client.
    """
    print("user-join", sid, data)

    # Non-dict payloads (None, strings, lists) carry no auth information.
    auth = data.get("auth") if isinstance(data, dict) else None
    # Initialise so a token without a usable "id" does not leave `user` unbound.
    user = None

    if auth and "token" in auth:
        token_data = decode_token(auth["token"])

        if token_data is not None and "id" in token_data:
            user = Users.get_user_by_id(token_data["id"])

        if user:
            SESSION_POOL[sid] = user.id
            sids = USER_POOL.setdefault(user.id, [])
            # connect may already have registered this session; avoid a
            # duplicate entry that disconnect would leave behind.
            if sid not in sids:
                sids.append(sid)

            print(f"user {user.name}({user.id}) connected with session ID {sid}")

            await sio.emit("user-count", {"count": len(USER_POOL)})


@sio.on("user-count")
async def user_count(sid, *args):
    """Broadcast the number of distinct connected users to all clients.

    Args:
        sid: Socket.IO session id of the requester.
        *args: Any payload the client attached. It is ignored, but accepting it
            keeps python-socketio from failing with a TypeError when the
            client emits the event with data.
    """
    await sio.emit("user-count", {"count": len(USER_POOL)})


def get_models_in_use():
    """Return the ids of all models that currently have a usage entry.

    Returns:
        list: A new list of USAGE_POOL's keys, in insertion order. Mutating it
        does not affect the pool.
    """
    return list(USAGE_POOL)


@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is using ``data["model"]`` and broadcast usage.

    Each ping (re)schedules a ``remove_after_timeout`` task for this session
    only. Other sessions using the same model keep their own pending removal,
    so every session expires independently. Previously, one shared per-model
    task was cancelled by any session's ping, which could leave a session in
    the pool forever.

    Args:
        sid: Socket.IO session id of the sender.
        data: Payload expected to be a dict with a ``"model"`` key. Malformed
            payloads are ignored.
    """
    if not isinstance(data, dict) or "model" not in data:
        return
    model_id = data["model"]

    entry = USAGE_POOL.setdefault(model_id, {"sids": []})
    # Per-session removal tasks (sid -> asyncio.Task).
    callbacks = entry.setdefault("callbacks", {})

    # Cancel only this session's pending removal; a fresh one is scheduled below.
    previous = callbacks.get(sid)
    if previous is not None:
        previous.cancel()

    if sid not in entry["sids"]:
        entry["sids"].append(sid)

    # Keep a reference to the task so it is not garbage-collected while pending.
    task = asyncio.create_task(remove_after_timeout(sid, model_id))
    callbacks[sid] = task
    # Most recently scheduled task, kept under the original key for compatibility.
    entry["callback"] = task

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})


async def remove_after_timeout(sid, model_id):
    """After TIMEOUT_DURATION seconds, drop ``sid`` from ``model_id``'s usage.

    If no session is left using the model, its USAGE_POOL entry is deleted.
    The updated list of models in use is then broadcast. Cancellation (by a
    newer ``usage`` ping) is swallowed, because a replacement task takes over.

    Args:
        sid: Session whose usage should expire.
        model_id: Key of the USAGE_POOL entry to update.
    """
    try:
        await asyncio.sleep(TIMEOUT_DURATION)
        entry = USAGE_POOL.get(model_id)
        if entry is None:
            return

        print(entry["sids"])
        # The sid may already be gone; list.remove would raise ValueError.
        if sid in entry["sids"]:
            entry["sids"].remove(sid)

        # Forget this task if the pool tracks removal tasks per session, but
        # only when the stored task is this one (never a newer replacement).
        callbacks = entry.get("callbacks")
        if callbacks is not None and callbacks.get(sid) is asyncio.current_task():
            del callbacks[sid]

        if not entry["sids"]:
            del USAGE_POOL[model_id]

        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to a newer 'usage' event
        pass


@sio.event
async def disconnect(sid):
    """Socket.IO ``disconnect`` handler: forget the session, rebroadcast user count.

    Removes ``sid`` from SESSION_POOL and from its user's USER_POOL list. The
    user is dropped once no sessions remain. USAGE_POOL is not touched here;
    usage entries expire through ``remove_after_timeout``.

    Args:
        sid: Socket.IO session id that disconnected.
    """
    if sid not in SESSION_POOL:
        print(f"Unknown session ID {sid} disconnected")
        return

    user_id = SESSION_POOL.pop(sid)

    sids = USER_POOL.get(user_id)
    if sids is not None:
        # Remove every occurrence in place: older code could register the same
        # session twice (connect + user-join).
        sids[:] = [s for s in sids if s != sid]
        if not sids:
            del USER_POOL[user_id]

    await sio.emit("user-count", {"count": len(USER_POOL)})


async def get_event_emitter(request_info):
    """Build a coroutine function that pushes ``chat-events`` to one session.

    The keys ``chat_id``, ``message_id`` and ``session_id`` are read from
    ``request_info`` each time the returned function is awaited, not when it
    is created.

    Args:
        request_info: Mapping with ``chat_id``, ``message_id`` and ``session_id``.

    Returns:
        An async function ``(event_data) -> None`` that emits the event to the
        session named by ``request_info["session_id"]``.
    """

    async def __event_emitter__(event_data):
        payload = {
            "chat_id": request_info["chat_id"],
            "message_id": request_info["message_id"],
            "data": event_data,
        }
        target = request_info["session_id"]
        await sio.emit("chat-events", payload, to=target)

    return __event_emitter__


async def get_event_call(request_info):
    """Build a coroutine function that sends ``chat-events`` and awaits a reply.

    Like ``get_event_emitter``, but it uses ``sio.call`` so the client's
    acknowledgement payload is returned to the caller. The request fields are
    read lazily, each time the returned function is awaited.

    Args:
        request_info: Mapping with ``chat_id``, ``message_id`` and ``session_id``.

    Returns:
        An async function ``(event_data) -> response`` that yields whatever
        ``sio.call`` returns.
    """

    async def __event_call__(event_data):
        payload = {
            "chat_id": request_info["chat_id"],
            "message_id": request_info["message_id"],
            "data": event_data,
        }
        target = request_info["session_id"]
        return await sio.call("chat-events", payload, to=target)

    return __event_call__