File size: 4,652 Bytes
d1eebad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import socketio
import asyncio
from apps.webui.models.users import Users
from utils.utils import decode_token
# Socket.IO server running in ASGI mode, mounted as an ASGI app whose
# Socket.IO endpoint lives under /ws/socket.io.
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=[])
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# In-memory bookkeeping shared by the event handlers below.
SESSION_POOL = {}  # session id -> user id
USER_POOL = {}  # user id -> list of that user's session ids
USAGE_POOL = {}  # model id -> dict with at least "sids" (session ids) and "callback" (expiry task)

# Delay, in seconds, before a model-usage report expires (see remove_after_timeout).
TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
    """Handle a new Socket.IO connection and register its user, if any.

    When ``auth`` is a dict carrying a ``"token"`` that decodes to a payload
    with an ``"id"`` resolving to a user, the session is recorded in
    ``SESSION_POOL`` (sid -> user id) and ``USER_POOL`` (user id -> sids), and
    the user count plus the current model-usage list are broadcast to all
    clients. Connections without a resolvable user are accepted silently.

    Args:
        sid: Socket.IO session id of the new connection.
        environ: Connection environ (unused here).
        auth: Authentication payload supplied by the client; may be ``None``.
    """
    user = None
    # Only a dict payload can carry a token. For a plain string, ``"token" in
    # auth`` would be a substring test and ``auth["token"]`` would then raise.
    if isinstance(auth, dict) and "token" in auth:
        token_data = decode_token(auth["token"])
        if token_data is not None and "id" in token_data:
            user = Users.get_user_by_id(token_data["id"])

    if not user:
        return

    SESSION_POOL[sid] = user.id
    sids = USER_POOL.setdefault(user.id, [])
    # Keep each session listed at most once so the per-user list stays an
    # accurate set of live sessions.
    if sid not in sids:
        sids.append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    await sio.emit("user-count", {"count": len(USER_POOL)})
    await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
    """Associate an already-connected session with a user via an auth token.

    Expects ``data`` shaped like ``{"auth": {"token": ...}}``. If the token
    decodes to a payload with an ``"id"`` that resolves to a user, the session
    is (re)registered under that user and the user count is broadcast.
    Malformed payloads and unresolvable tokens are ignored.

    Args:
        sid: Socket.IO session id of the joining client.
        data: Event payload sent by the client.
    """
    print("user-join", sid, data)

    auth = data.get("auth") if isinstance(data, dict) else None
    if not (isinstance(auth, dict) and "token" in auth):
        return

    token_data = decode_token(auth["token"])
    if token_data is None or "id" not in token_data:
        return

    user = Users.get_user_by_id(token_data["id"])
    if not user:
        return

    # If this session was registered to a different user before, detach it
    # from that user; otherwise the old user would keep a stale sid (and stay
    # counted) after the session disconnects.
    previous_user_id = SESSION_POOL.get(sid)
    if previous_user_id is not None and previous_user_id != user.id:
        old_sids = USER_POOL.get(previous_user_id)
        if old_sids is not None:
            remaining = [s for s in old_sids if s != sid]
            if remaining:
                USER_POOL[previous_user_id] = remaining
            else:
                del USER_POOL[previous_user_id]

    SESSION_POOL[sid] = user.id
    sids = USER_POOL.setdefault(user.id, [])
    # The same sid may already be registered (e.g. by ``connect``); listing it
    # twice would leave a stale copy behind after a single removal.
    if sid not in sids:
        sids.append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")
    await sio.emit("user-count", {"count": len(USER_POOL)})
@sio.on("user-count")
async def user_count(sid, *_args):
    """Broadcast the number of distinct connected users to every client.

    Args:
        sid: Socket.IO session id of the requesting client.
        *_args: Any payload sent along with the event (like ``data`` in the
            other handlers). It is accepted so such a request does not fail
            on a signature mismatch, and is otherwise ignored.
    """
    # USER_POOL is keyed by user id, so its length is the distinct-user count.
    await sio.emit("user-count", {"count": len(USER_POOL)})
def get_models_in_use():
    """Return the ids of all models currently tracked in ``USAGE_POOL``.

    Returns:
        A new list of model ids, in ``USAGE_POOL`` insertion order.
    """
    # Iterating a dict yields its keys; list() copies them so callers can
    # safely hold the result while the pool changes.
    return list(USAGE_POOL)
@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is using a model and broadcast the usage list.

    Each report (re)starts an expiry timer for this particular session, which
    runs ``remove_after_timeout``. A model stays listed while any session keeps
    reporting it, and is dropped once every reporting session's timer expires.

    Args:
        sid: Socket.IO session id of the reporting client.
        data: Event payload; the model id is read from ``data["model"]``.
            Payloads without a model are ignored.
    """
    model_id = data.get("model") if isinstance(data, dict) else None
    if model_id is None:
        return

    entry = USAGE_POOL.setdefault(model_id, {"sids": []})

    # Timers are tracked per session. With one shared timer per model, a new
    # report cancelled the previous session's timer, so that session's sid was
    # never removed and the model was reported "in use" forever.
    timers = entry.setdefault("callbacks", {})
    previous = timers.get(sid)
    if previous is not None:
        previous.cancel()

    if sid not in entry["sids"]:
        entry["sids"].append(sid)

    task = asyncio.create_task(remove_after_timeout(sid, model_id))
    timers[sid] = task
    # Keep the original "callback" key pointing at the newest timer for any
    # code that still reads it.
    entry["callback"] = task

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
    """Expire ``sid``'s usage of ``model_id`` after ``TIMEOUT_DURATION`` seconds.

    Intended to run as a background task scheduled by the ``usage`` handler.
    After the delay:
      * the session is dropped from the model's ``sids`` list;
      * the model entry is deleted once no sessions remain;
      * the refreshed usage list is broadcast.

    Cancellation, which happens when a newer report supersedes this timer, is
    expected and swallowed.

    Args:
        sid: Session whose usage report is expiring.
        model_id: Key of the model entry in ``USAGE_POOL``.
    """
    try:
        await asyncio.sleep(TIMEOUT_DURATION)

        entry = USAGE_POOL.get(model_id)
        if entry is None:
            return

        # Filter rather than list.remove(): this never raises ValueError (which
        # would go unobserved in a fire-and-forget task) and drops every copy.
        entry["sids"] = [s for s in entry["sids"] if s != sid]

        # Release our own per-session timer handle, if one was recorded, so
        # finished tasks do not accumulate on long-lived model entries.
        timers = entry.get("callbacks")
        if timers is not None and timers.get(sid) is asyncio.current_task():
            del timers[sid]

        if not entry["sids"]:
            del USAGE_POOL[model_id]

        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass
@sio.event
async def disconnect(sid):
    """Forget a disconnected session and broadcast the user count if it changed.

    Sessions that were never associated with a user are only logged. The user
    count is broadcast only when the user's last session goes away.

    Args:
        sid: Socket.IO session id of the disconnected client.
    """
    if sid not in SESSION_POOL:
        print(f"Unknown session ID {sid} disconnected")
        return

    user_id = SESSION_POOL.pop(sid)

    sids = USER_POOL.get(user_id)
    if sids is None:
        return

    # Remove every occurrence. If the sid had been registered twice, removing
    # a single copy would keep the user counted forever.
    remaining = [s for s in sids if s != sid]
    if remaining:
        USER_POOL[user_id] = remaining
    else:
        del USER_POOL[user_id]
        await sio.emit("user-count", {"count": len(USER_POOL)})
async def get_event_emitter(request_info):
    """Build a one-way ``chat-events`` emitter bound to a single request.

    Args:
        request_info: Mapping read (at each emit, not at creation time) for the
            ``"chat_id"``, ``"message_id"`` and ``"session_id"`` keys.

    Returns:
        An async callable taking ``event_data``. It emits ``chat-events`` with
        the chat id, message id and the data to the request's session only.
    """

    async def __event_emitter__(event_data):
        payload = dict(
            chat_id=request_info["chat_id"],
            message_id=request_info["message_id"],
            data=event_data,
        )
        target_sid = request_info["session_id"]
        await sio.emit("chat-events", payload, to=target_sid)

    return __event_emitter__
async def get_event_call(request_info):
    """Build a request/response ``chat-events`` caller bound to a single request.

    Args:
        request_info: Mapping read (at each call, not at creation time) for the
            ``"chat_id"``, ``"message_id"`` and ``"session_id"`` keys.

    Returns:
        An async callable taking ``event_data``. It sends ``chat-events`` to
        the request's session via ``sio.call`` and returns the client's
        response.
    """

    async def __event_call__(event_data):
        message = {
            "chat_id": request_info["chat_id"],
            "message_id": request_info["message_id"],
            "data": event_data,
        }
        return await sio.call("chat-events", message, to=request_info["session_id"])

    return __event_call__
|