# Hosting-page residue captured with this file ("Spaces: Build error"); not part of the module.
import enum
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from fastapi import WebSocket, websockets
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel

from core.db.redis_session import redis_chat_client
from core.security import get_uid_hash
from models import User
@enum.unique
class ChatMessageTypes(enum.Enum):
    """Numeric codes carried in the ``msg_type`` field of a chat message.

    Members are deliberately left unannotated: an annotation such as
    ``NAME: int = 1`` describes the member as an ``int`` although it is a
    ``ChatMessageTypes`` instance, and type checkers reject it.
    """

    # Stored chat history for a class session.
    MESSAGE_HISTORY = 1
    # Chat message that carries the sender's user id.
    PUBLIC_MESSAGE = 2
    # Chat message that carries a hash of the sender's user id instead.
    ANON_MESSAGE = 3
    # A user opened a connection to the session.
    USER_JOINED = 4
    # A user's connection to the session was closed.
    USER_LEFT = 5
    # List of users currently recorded as active in the session.
    ACTIVE_USER_LIST = 6
class Message(BaseModel):
    """Envelope for every payload sent over the chat WebSocket.

    ``data`` and ``user`` are optional and default to ``None`` explicitly, so
    that omitting them (as the join/leave and history messages in this file
    do) does not depend on pydantic's implicit handling of ``Optional``
    fields without a default.
    """

    # One of the ChatMessageTypes values.
    msg_type: int
    # Message text, or a JSON string read from Redis (history / active users).
    data: Optional[str] = None
    # Sender's user id, or the hash of it for anonymous messages.
    user: Optional[str] = None
    # Creation time; callers in this file pass datetime.utcnow().
    time: datetime
class WebSocketManager:
    """Track chat WebSocket connections per class session and relay messages.

    Open sockets are kept in ``self.connections`` (class session id -> list
    of sockets). Chat history and the list of active users are stored as
    JSON strings in Redis through ``redis_chat_client``.
    """

    # TTL handed to Redis for the history and active-user keys.
    # NOTE(review): 60 * 60 * 1000 reads like "one hour in milliseconds",
    # while expire-style arguments are normally seconds -- confirm the unit.
    _KEY_TTL = 60 * 60 * 1000

    _log = logging.getLogger(__name__)

    def __init__(self):
        self.connections: Dict[int, List[WebSocket]] = {}

    @staticmethod
    def _history_key(class_session_id: int) -> str:
        """Return the Redis key holding the chat history of a session."""
        return f"chat_class_sess_{class_session_id}"

    @staticmethod
    def _active_key(class_session_id: int) -> str:
        """Return the Redis key holding the active-user list of a session."""
        return f"active_status_{class_session_id}"

    def _register(self, websocket: WebSocket, class_session_id: int) -> None:
        """Add ``websocket`` to the session, creating the session entry if needed."""
        self.connections.setdefault(class_session_id, []).append(websocket)

    def _unregister(self, websocket: WebSocket, class_session_id: int) -> None:
        """Remove ``websocket`` from the session; unknown sockets are ignored."""
        sockets = self.connections.get(class_session_id)
        if sockets is None:
            return
        if websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            # Drop empty sessions so the mapping does not grow without bound.
            self.connections.pop(class_session_id, None)

    async def _store_active_users(self, class_session_id: int, users: list) -> None:
        """Write the active-user list of a session to Redis and refresh its TTL."""
        key = self._active_key(class_session_id)
        await redis_chat_client.client.set(
            key, json.dumps(users, separators=(",", ":"))
        )
        await redis_chat_client.client.expire(key, self._KEY_TTL)

    async def update(self, data, key):
        """Append ``data`` to the JSON list stored under Redis ``key``.

        A missing or empty value is treated as an empty list. The read and
        the write are two separate Redis calls.
        """
        raw = await redis_chat_client.client.get(key)
        entries = json.loads(raw) if raw else []
        entries.append(data)
        await redis_chat_client.client.set(
            key, json.dumps(entries, separators=(",", ":")), expire=self._KEY_TTL
        )

    async def send_history(self, websocket: WebSocket, class_session_id: int):
        """Send the stored chat history of the session to one socket.

        The history travels as the raw JSON string read from Redis in the
        ``data`` field; when nothing is stored, ``data`` is omitted.
        """
        chat_history = await redis_chat_client.client.get(
            self._history_key(class_session_id), encoding="UTF-8"
        )
        history_msg = Message(
            msg_type=ChatMessageTypes.MESSAGE_HISTORY.value,
            data=chat_history,
            time=datetime.utcnow(),
        )
        await websocket.send_json(
            jsonable_encoder(history_msg.dict(exclude_none=True))
        )

    async def connect(self, websocket: WebSocket, user_id: int, class_session_id: int):
        """Accept a socket, announce the user and record them as active.

        Steps: accept and register the socket, broadcast USER_JOINED to the
        session (not saved to history), send the new socket the active-user
        list as it was before this user was added, then store the updated
        list in Redis. Chat history is not sent here; see ``send_history``.
        """
        await websocket.accept()
        self._register(websocket, class_session_id)

        joined_msg = Message(
            msg_type=ChatMessageTypes.USER_JOINED.value,
            time=datetime.utcnow(),
            user=user_id,
        )
        await self.broadcast(
            joined_msg.dict(exclude_none=True), user_id, class_session_id, save=False
        )

        raw_active = await redis_chat_client.client.get(
            self._active_key(class_session_id), encoding="UTF-8"
        )
        active_msg = Message(
            msg_type=ChatMessageTypes.ACTIVE_USER_LIST.value,
            data=raw_active,
            time=datetime.utcnow(),
        )
        await websocket.send_json(
            jsonable_encoder(active_msg.dict(exclude_none=True))
        )

        active_users = json.loads(raw_active) if raw_active else []
        active_users.append(user_id)
        # De-duplicate while keeping a stable order (a user may reconnect).
        await self._store_active_users(
            class_session_id, list(dict.fromkeys(active_users))
        )

    async def disconnect(
        self, websocket: WebSocket, user_id: int, class_session_id: int
    ):
        """Forget a socket, announce the departure and update the active list.

        Safe to call for a socket, session or user that is no longer known
        (for example after the Redis key expired or a second connection of
        the same user already removed them).
        """
        self._unregister(websocket, class_session_id)

        left_msg = Message(
            msg_type=ChatMessageTypes.USER_LEFT.value,
            time=datetime.utcnow(),
            user=user_id,
        )
        # Serialised like the join event: fields that are None are dropped.
        await self.broadcast(
            left_msg.dict(exclude_none=True), user_id, class_session_id, save=False
        )

        raw_active = await redis_chat_client.client.get(
            self._active_key(class_session_id)
        )
        active_users = json.loads(raw_active) if raw_active else []
        if user_id in active_users:
            active_users.remove(user_id)
        await self._store_active_users(class_session_id, active_users)

    async def broadcast(
        self, data: Any, user_id: int, class_session_id: int, save: bool = True
    ):
        """Send ``data`` to every socket of the session, optionally saving it.

        Delivery is best effort: a failing socket is logged and skipped so
        the remaining sockets still receive the message. A session without
        connections is a no-op for delivery. When ``save`` is true the
        encoded payload is appended to the session's history in Redis.
        """
        encoded_data = jsonable_encoder(data)
        # Iterate over a snapshot: other tasks may connect or disconnect
        # while this coroutine is suspended in send_json().
        for connection in list(self.connections.get(class_session_id, ())):
            try:
                await connection.send_json(encoded_data)
            except Exception:
                self._log.warning(
                    "Failed to send chat message to a client in session %s",
                    class_session_id,
                    exc_info=True,
                )
        if save:
            await self.update(encoded_data, self._history_key(class_session_id))

    async def message(
        self,
        websocket: WebSocket,
        message: str,
        user_id: int,
        class_session_id: int,
        anon: bool = False,
    ):
        """Broadcast a chat message to the session and save it to history.

        With ``anon`` set, the message is typed ANON_MESSAGE and carries
        ``get_uid_hash(user_id)`` instead of the user id. ``websocket`` is
        accepted for interface compatibility and is not used.
        """
        if anon:
            msg_type = ChatMessageTypes.ANON_MESSAGE.value
            user = get_uid_hash(user_id)
        else:
            msg_type = ChatMessageTypes.PUBLIC_MESSAGE.value
            user = user_id
        chat_msg = Message(
            msg_type=msg_type,
            data=message,
            user=user,
            time=datetime.utcnow(),
        )
        await self.broadcast(
            chat_msg.dict(exclude_none=True), user_id, class_session_id
        )
# Module-level WebSocketManager instance.
# NOTE(review): presumably shared by the WebSocket route handlers -- confirm.
ws = WebSocketManager()