# Last commit not found
import json
import logging
from typing import Dict

from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from jose import JWTError, jwt
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

from .auth import SECRET_KEY, ALGORITHM
from .db.database import get_user_by_username


class ConnectionManager:
    """Registry of authenticated WebSocket clients and the LLM that answers them.

    Connections are keyed by username, so a later ``connect`` for the same
    username replaces the socket registered earlier.
    """

    def __init__(self) -> None:
        # username -> socket that replies for that user are written to
        self.active_connections: Dict[str, "WebSocket"] = {}
        # Single chat model instance shared by every connection.
        self.llm = ChatOpenAI(model="gpt-4o-mini")
        # username -> runnable used to answer that user's messages
        self.chains = {}

    async def connect(self, websocket: "WebSocket", username: str) -> None:
        """Register an already-accepted socket and confirm it to the client.

        ``websocket.accept()`` is deliberately not called here; the caller
        (``handle_websocket``) accepts the socket before authenticating.
        """
        self.active_connections[username] = websocket
        self.chains[username] = self.llm
        # Tell the client that authentication and registration succeeded.
        await websocket.send_json({
            "type": "connection_established",
            "message": f"Connected as {username}",
        })

    def disconnect(self, username: str) -> None:
        """Forget everything stored for ``username``; unknown names are a no-op."""
        self.active_connections.pop(username, None)
        # Remove the entry rather than storing None, so the mapping does not
        # keep one key per user that ever disconnected.
        self.chains.pop(username, None)

    async def send_message(self, message: str, username: str) -> None:
        """Stream the model's reply to ``message`` over the user's socket.

        Does nothing when ``username`` has no registered socket. Each streamed
        chunk that is an ``AIMessage`` is sent as a ``"message"`` frame. Any
        failure other than a client disconnect is reported to the client as an
        ``"error"`` frame.

        Raises:
            WebSocketDisconnect: if the client goes away while streaming.
        """
        websocket = self.active_connections.get(username)
        if websocket is None:
            return
        try:
            chain = self.chains[username]
            async for chunk in chain.astream(message):
                if isinstance(chunk, AIMessage):
                    await websocket.send_json({
                        "type": "message",
                        "message": chunk.content,
                        "sender": "ai",
                    })
        except WebSocketDisconnect:
            # The socket is gone; writing an error frame to it would only
            # raise again. Let the connection handler do the cleanup.
            raise
        except Exception as e:
            await websocket.send_json({
                "type": "error",
                "message": str(e),
            })
# Module-level registry used by handle_websocket for every connection.
# Instantiating it also constructs the ChatOpenAI client, so that happens
# when this module is imported.
manager = ConnectionManager()
def _token_from_auth_frame(raw):
    """Pull the JWT out of the first frame a client sends.

    Accepts either a JSON object carrying a ``token`` field or the bare token
    string. Returns ``None`` when a JSON object was sent without a usable
    (non-empty string) ``token``.
    """
    try:
        envelope = json.loads(raw)
    except json.JSONDecodeError:
        # Not JSON at all: the frame is the raw token.
        return raw
    if not isinstance(envelope, dict):
        # Valid JSON but not an object (e.g. ``123``): it cannot hold a
        # "token" key, so let JWT verification judge the frame as a raw token.
        return raw
    token = envelope.get("token")
    return token if isinstance(token, str) and token else None


async def handle_websocket(websocket: WebSocket):
    """Authenticate a WebSocket client, then relay its chat messages to the LLM.

    Protocol:
      1. The first frame carries a JWT, either as ``{"token": "..."}`` or as
         the bare token string.
      2. Every later frame is either a JSON object
         ``{"type": "message", "content": ...}`` or plain text. JSON objects
         with any other ``type`` are ignored.

    The socket is closed with code 1008 when authentication fails (bad or
    missing token, no ``sub`` claim, unknown user) and with 1011 on an
    unexpected server-side error.
    """
    log = logging.getLogger(__name__)
    await websocket.accept()  # Accept the connection once, before auth
    username = None
    try:
        token = _token_from_auth_frame(await websocket.receive_text())
        try:
            if token is None:
                raise JWTError("missing token")
            payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        except JWTError:
            await websocket.send_json({
                "type": "error",
                "message": "Authentication failed",
            })
            await websocket.close(code=1008)
            return

        username = payload.get("sub")
        if not username:
            await websocket.close(code=1008)
            return

        # The token is valid, but the account must still exist.
        user = await get_user_by_username(username)
        if not user:
            await websocket.close(code=1008)
            return

        await manager.connect(websocket, username)

        # Main message loop; ends when receive_text raises on disconnect.
        while True:
            frame = await websocket.receive_text()
            try:
                data = json.loads(frame)
            except json.JSONDecodeError:
                data = None
            if not isinstance(data, dict):
                # Plain text. This includes input such as "42" or "true",
                # which parses as JSON but is not a message envelope.
                await manager.send_message(frame, username)
            elif data.get("type") == "message":
                await manager.send_message(data.get("content", ""), username)
    except WebSocketDisconnect:
        # Client went away; cleanup happens in ``finally``.
        pass
    except Exception:
        log.exception("WebSocket error")
        try:
            await websocket.close(code=1011)
        except Exception:
            # Best effort: the socket may already be closed.
            log.debug("Could not close WebSocket after error", exc_info=True)
    finally:
        # Only unregister if *this* socket is the one on record, so a failed
        # or superseded connection cannot evict another live session that
        # belongs to the same user.
        if (
            username is not None
            and manager.active_connections.get(username) is websocket
        ):
            manager.disconnect(username)