import argparse
import asyncio
import hashlib
import json
import logging
import os
import shutil
import sys
import time
import urllib.parse
import uuid
from datetime import datetime
from pathlib import Path

import gradio as gr
import websockets

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("chat-node")

# Active WebSocket connections: room_id -> {client_id: websocket}
active_connections = {}

# In-memory cache of message history: room_id -> list of message dicts
chat_history = {}

# Last observed modification time of each history file: file path -> mtime
file_modification_times = {}

# Connected client ids per room/sector: room_id -> set of client ids
sector_users = {}

# Grid dimensions for 2D sector map
GRID_WIDTH = 10
GRID_HEIGHT = 10

# Directory to store persistent chat history
HISTORY_DIR = "chat_history"

# History files are named "{room_id}_{YYYYmmdd}_{HHMMSS}.jsonl"
HISTORY_SUFFIX = ".jsonl"

# Maximum number of messages the WebSocket handler keeps in memory per room
MAX_ROOM_HISTORY = 500

# Create history directory if it doesn't exist
os.makedirs(HISTORY_DIR, exist_ok=True)

# README.md file that won't be listed or deleted
README_PATH = os.path.join(HISTORY_DIR, "README.md")
if not os.path.exists(README_PATH):
    with open(README_PATH, "w") as f:
        f.write("# Chat History\n\nThis directory contains persistent chat history files.\n")


def get_node_name():
    """Parse this node's command-line options.

    Returns:
        tuple[str, int]: the node name (``--node-name``, or a random
        ``node-xxxxxxxx`` name when omitted) and the Gradio port
        (``--port``, default 7860).
    """
    parser = argparse.ArgumentParser(description='Start a chat node with a specific name')
    parser.add_argument('--node-name', type=str, default=None, help='Name for this chat node')
    parser.add_argument('--port', type=int, default=7860, help='Port to run the Gradio interface on')
    args = parser.parse_args()
    # If no node name specified, generate a random one
    node_name = args.node_name or f"node-{uuid.uuid4().hex[:8]}"
    return node_name, args.port


def get_room_history_file(room_id):
    """Return a new timestamped history file path for *room_id* (the file is not created)."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(HISTORY_DIR, f"{room_id}_{timestamp}{HISTORY_SUFFIX}")


def _room_id_from_history_filename(filename):
    """Return the room id encoded in a history file name, or None if it is not one.

    Splitting from the right keeps room ids that contain underscores intact,
    and requiring the date/time parts prevents room "a" from matching room
    "a_b"'s files.
    """
    if not filename.endswith(HISTORY_SUFFIX):
        return None
    parts = filename[:-len(HISTORY_SUFFIX)].rsplit("_", 2)
    if len(parts) != 3:
        return None
    room_id, date_part, time_part = parts
    if not (room_id
            and len(date_part) == 8 and date_part.isdigit()
            and len(time_part) == 6 and time_part.isdigit()):
        return None
    return room_id


def get_all_room_history_files(room_id):
    """Return all history file paths belonging exactly to *room_id*, newest first (by mtime)."""
    files = [
        os.path.join(HISTORY_DIR, name)
        for name in os.listdir(HISTORY_DIR)
        if _room_id_from_history_filename(name) == room_id
    ]
    files.sort(key=os.path.getmtime, reverse=True)
    return files


def load_room_history(room_id):
    """Load every persisted message of *room_id* into ``chat_history`` and return it.

    Messages from all of the room's files are merged and sorted by their
    ``timestamp`` field. Unreadable files and malformed lines are logged and
    skipped. Also makes sure ``sector_users`` has an entry for the room.
    """
    history_files = get_all_room_history_files(room_id)

    # Remember modification times so check_for_new_messages() can detect changes
    for history_file in history_files:
        if history_file not in file_modification_times:
            file_modification_times[history_file] = os.path.getmtime(history_file)

    messages = []
    for history_file in history_files:
        try:
            with open(history_file, 'r', encoding='utf-8') as fh:
                for line in fh:
                    line = line.strip()
                    if not line:  # Skip empty lines
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        logger.error(f"Error parsing JSON line in {history_file}")
                        continue
                    if isinstance(data, dict):
                        messages.append(data)
        except (OSError, UnicodeDecodeError) as e:
            logger.error(f"Error loading history from {history_file}: {e}")

    # ISO-8601 timestamps sort chronologically as strings
    messages.sort(key=lambda m: m.get("timestamp", ""))
    chat_history[room_id] = messages
    logger.info(f"Loaded {len(messages)} messages from {len(history_files)} files for room {room_id}")

    sector_users.setdefault(room_id, set())
    return chat_history[room_id]


def save_message_to_history(room_id, message):
    """Append *message* as one JSON line to the room's newest history file.

    A new timestamped file is started when the room has none yet or the
    newest one exceeds 1 MB. Write errors are logged, not raised.
    """
    history_files = get_all_room_history_files(room_id)
    if history_files and os.path.getsize(history_files[0]) <= 1024 * 1024:  # 1 MB
        history_file = history_files[0]
    else:
        history_file = get_room_history_file(room_id)

    try:
        with open(history_file, 'a', encoding='utf-8') as fh:
            fh.write(json.dumps(message) + '\n')
        file_modification_times[history_file] = os.path.getmtime(history_file)
        logger.debug(f"Saved message to {history_file}")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error saving message to {history_file}: {e}")


def check_for_new_messages():
    """Reload cached rooms whose history files are new or were modified.

    Returns:
        set[str]: ids of every room with a new/changed history file (whether
        or not it was cached and reloaded).
    """
    updated_rooms = set()
    for name in os.listdir(HISTORY_DIR):
        room_id = _room_id_from_history_filename(name)
        if room_id is None:
            continue
        file_path = os.path.join(HISTORY_DIR, name)
        current_mtime = os.path.getmtime(file_path)
        previous_mtime = file_modification_times.get(file_path)
        if previous_mtime is None or current_mtime > previous_mtime:
            updated_rooms.add(room_id)
            file_modification_times[file_path] = current_mtime

    # Only rooms already held in memory are reloaded
    for room_id in updated_rooms:
        if room_id in chat_history:
            old_history_len = len(chat_history[room_id])
            load_room_history(room_id)
            new_history_len = len(chat_history[room_id])
            if new_history_len > old_history_len:
                logger.info(f"Found {new_history_len - old_history_len} new messages for room {room_id}")

    return updated_rooms


def get_sector_coordinates(room_id):
    """Map a room id to ``(x, y)`` grid coordinates.

    Ids of the form ``"x,y"`` (also percent-encoded, e.g. ``"1%2C2"``) are
    clamped onto the grid. Any other id is hashed with SHA-256, so the
    result is stable across processes and nodes (unlike built-in ``hash()``,
    which is randomized per process for strings).
    """
    name = urllib.parse.unquote(room_id)
    if ',' in name:
        try:
            x, y = (int(part) for part in name.split(','))
        except ValueError:
            pass  # Not an "x,y" pair; fall back to hashing
        else:
            return max(0, min(x, GRID_WIDTH - 1)), max(0, min(y, GRID_HEIGHT - 1))

    digest = hashlib.sha256(name.encode("utf-8")).digest()
    hash_val = int.from_bytes(digest[:8], "big")
    return hash_val % GRID_WIDTH, (hash_val >> 8) % GRID_HEIGHT


def generate_sector_map():
    """Render the sector grid as a fenced ASCII block plus a legend.

    Each occupied cell shows the total number of connected users of all
    rooms mapped to it (1-9), or '+' for 10 or more.
    """
    counts = {}
    for room_id, users in sector_users.items():
        if users:  # Only show rooms with users
            cell = get_sector_coordinates(room_id)
            counts[cell] = counts.get(cell, 0) + len(users)

    def _cell(n):
        if n == 0:
            return ' '
        return str(n) if n < 10 else '+'

    # Rows are prefixed with "y|", so the axis lines get a two-character indent
    axis = '  ' + ''.join(str(i % 10) for i in range(GRID_WIDTH))
    rows = [
        f"{y % 10}|" + ''.join(_cell(counts.get((x, y), 0)) for x in range(GRID_WIDTH)) + '|'
        for y in range(GRID_HEIGHT)
    ]
    map_str = '\n'.join([axis, *rows, axis])
    return f"```\n{map_str}\n```\n\nLegend: Number indicates users in sector. '+' means 10+ users."


def _system_message(content, room_id, msg_type="system"):
    """Build a server-generated message envelope for *room_id*."""
    return {
        "type": msg_type,
        "content": content,
        "timestamp": datetime.now().isoformat(),
        "sender": "system",
        "room_id": room_id,
    }


async def clear_all_history():
    """Delete all persisted and cached history and notify every active room.

    Only ``*.jsonl`` history files are removed; README.md is kept.

    Returns:
        str: a confirmation message.
    """
    # Clear in place so every reference to the global dict stays valid
    chat_history.clear()

    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(HISTORY_SUFFIX):
            continue
        path = os.path.join(HISTORY_DIR, name)
        try:
            os.remove(path)
        except OSError as e:
            logger.error(f"Error deleting history file {path}: {e}")
        file_modification_times.pop(path, None)

    for room_id in list(active_connections.keys()):
        clear_msg = _system_message("🧹 All chat history has been cleared by a user", room_id)
        await broadcast_message(clear_msg, room_id)

    logger.info("All chat history cleared")
    return "All chat history cleared"


async def _handle_client_message(websocket, room_id, raw_message):
    """Process one frame received from a client connected to *room_id*.

    Commands are answered directly. Any other JSON object is stamped with
    time/node/room, cached (capped at MAX_ROOM_HISTORY), persisted and
    broadcast to the room.
    """
    try:
        data = json.loads(raw_message)
    except ValueError:  # JSONDecodeError, or undecodable bytes
        await websocket.send(json.dumps(_system_message("Invalid JSON format", room_id, "error")))
        return
    if not isinstance(data, dict):
        # Valid JSON but not an object: the .get() calls below would raise
        await websocket.send(json.dumps(
            _system_message("Invalid message format: expected a JSON object", room_id, "error")
        ))
        return

    if data.get("type") == "command":
        command = data.get("command")
        if command == "clear_history":
            await clear_all_history()
            return
        if command == "show_map":
            await websocket.send(json.dumps(
                _system_message(f"Sector Map:\n{generate_sector_map()}", room_id)
            ))
            return

    # Add metadata to the message
    data["timestamp"] = datetime.now().isoformat()
    data["sender_node"] = NODE_NAME
    data["room_id"] = room_id

    # setdefault: clear_all_history() may have emptied the cache while connected
    room_messages = chat_history.setdefault(room_id, [])
    room_messages.append(data)
    if len(room_messages) > MAX_ROOM_HISTORY:
        chat_history[room_id] = room_messages[-MAX_ROOM_HISTORY:]

    save_message_to_history(room_id, data)
    await broadcast_message(data, room_id)


async def websocket_handler(websocket, path=None):
    """Serve one WebSocket client in the room named by the first path segment.

    On connect the client is registered, receives a welcome message, the
    sector map and the room's stored history, and a join notice is
    broadcast and persisted. On disconnect it is unregistered and a leave
    notice is broadcast and persisted.

    Args:
        websocket: connection object supplied by ``websockets``.
        path: request path such as ``"/lobby"``. Legacy ``websockets``
            servers pass it; newer ones call the handler with the connection
            only, so it is then read from the connection.
    """
    if path is None:
        request = getattr(websocket, "request", None)
        path = getattr(request, "path", None) or getattr(websocket, "path", None) or ""
    # Drop any query string and use the first segment; "/" maps to "default"
    room_id = path.split("?", 1)[0].strip("/").split("/")[0] or "default"
    client_id = str(uuid.uuid4())
    x, y = get_sector_coordinates(room_id)

    try:
        # Register the new client and add it to the sector map
        active_connections.setdefault(room_id, {})[client_id] = websocket
        sector_users.setdefault(room_id, set()).add(client_id)

        room_history = load_room_history(room_id)

        await websocket.send(json.dumps(_system_message(
            f"Welcome to room '{room_id}' (Sector {x},{y})! Connected from node '{NODE_NAME}'",
            room_id,
        )))
        await websocket.send(json.dumps(
            _system_message(f"Sector Map:\n{generate_sector_map()}", room_id)
        ))
        for msg in room_history:
            await websocket.send(json.dumps(msg))

        join_msg = _system_message(
            f"User joined the room (Sector {x},{y}) - {len(sector_users[room_id])} users now present",
            room_id,
        )
        await broadcast_message(join_msg, room_id)
        save_message_to_history(room_id, join_msg)
        logger.info(f"New client {client_id} connected to room {room_id} (Sector {x},{y})")

        async for message in websocket:
            await _handle_client_message(websocket, room_id, message)
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Client {client_id} disconnected from room {room_id}")
    finally:
        # broadcast_message() may already have dropped this client, so use pop/discard
        room_clients = active_connections.get(room_id)
        if room_clients is not None:
            room_clients.pop(client_id, None)
        users = sector_users.get(room_id)
        if users is not None:
            users.discard(client_id)

        leave_msg = _system_message(
            f"User left the room (Sector {x},{y}) - {len(sector_users.get(room_id, set()))} users remaining",
            room_id,
        )
        await broadcast_message(leave_msg, room_id)
        save_message_to_history(room_id, leave_msg)

        # Clean up empty rooms (but keep history); re-check after the await above
        if room_id in active_connections and not active_connections[room_id]:
            del active_connections[room_id]


async def broadcast_message(message, room_id):
    """Send *message* as JSON to every client in *room_id*, dropping closed ones."""
    room_clients = active_connections.get(room_id)
    if not room_clients:
        return
    payload = json.dumps(message)
    disconnected_clients = []
    # Iterate a snapshot: other handlers may add/remove clients while we await
    for client_id, websocket in list(room_clients.items()):
        try:
            await websocket.send(payload)
        except websockets.exceptions.ConnectionClosed:
            disconnected_clients.append(client_id)
    for client_id in disconnected_clients:
        room_clients.pop(client_id, None)


async def start_websocket_server(host='0.0.0.0', port=8765):
    """Start the WebSocket server on *host*:*port* and return the server object."""
    server = await websockets.serve(websocket_handler, host, port)
    logger.info(f"WebSocket server started on ws://{host}:{port}")
    return server


# Global variables for event loop and queue
main_event_loop = None
message_queue = []


def send_message(message, username, room_id):
    """Queue a chat message from the Gradio UI and return its display form.

    Returns:
        str | None: ``"username: message"``, or None for a blank message.
    """
    if not message.strip():
        return None
    message_queue.append({
        "type": "chat",
        "content": message,
        "username": username,
        "room_id": room_id,
    })
    return f"{username}: {message}"


def join_room(room_id, chat_history_output):
    """Load a room's history for the Gradio UI.

    Returns:
        tuple[str, str]: the "Joined room: <id>" status and the formatted
        history as newline-separated text (or the unchanged chat box and an
        error status for a blank id).
    """
    if not room_id.strip():
        return "Please enter a valid room ID", chat_history_output

    # Sanitize the room ID; keep ',' so "x,y" rooms match WebSocket rooms and grid coordinates
    room_id = urllib.parse.quote(room_id.strip(), safe=",")

    history = load_room_history(room_id)

    lines = []
    for msg in history:
        msg_type = msg.get("type")
        if msg_type == "chat":
            sender_node = f" [{msg.get('sender_node', 'unknown')}]" if "sender_node" in msg else ""
            time_str = ""
            if "timestamp" in msg:
                try:
                    time_str = f"[{datetime.fromisoformat(msg['timestamp']).strftime('%H:%M:%S')}] "
                except (TypeError, ValueError):
                    pass  # Unparseable timestamp: show the message without a time
            lines.append(f"{time_str}{msg.get('username', 'Anonymous')}{sender_node}: {msg.get('content', '')}")
        elif msg_type == "system":
            lines.append(f"System: {msg.get('content', '')}")

    # A Textbox needs a string; a list would be rendered as its Python repr
    return f"Joined room: {room_id}", "\n".join(lines)


def send_clear_command():
    """Queue a request to clear all chat history and return a status text."""
    message_queue.append({
        "type": "command",
        "command": "clear_history",
        "username": "System",
    })
    return "🧹 Clearing all chat history..."
def get_all_history_files():
    """List the newest history file of every room in HISTORY_DIR.

    Only files named ``{room_id}_{YYYYmmdd}_{HHMMSS}.jsonl`` are considered;
    README.md and stray files are ignored. Splitting from the right keeps
    room ids that contain underscores intact.

    Returns:
        list[tuple[str, str, float]]: ``(room_id, file_path, mod_time)`` per
        room, most recently modified room first.
    """
    suffix = ".jsonl"
    newest = {}
    for name in os.listdir(HISTORY_DIR):
        if not name.endswith(suffix):
            continue
        parts = name[:-len(suffix)].rsplit("_", 2)
        if len(parts) != 3:
            continue
        room_id, date_part, time_part = parts
        if not (room_id
                and len(date_part) == 8 and date_part.isdigit()
                and len(time_part) == 6 and time_part.isdigit()):
            continue
        file_path = os.path.join(HISTORY_DIR, name)
        try:
            mod_time = os.path.getmtime(file_path)
        except OSError:
            continue  # Removed between listdir() and stat(), e.g. by a concurrent clear
        if room_id not in newest or mod_time > newest[room_id][1]:
            newest[room_id] = (file_path, mod_time)

    rooms = [(room_id, path, mtime) for room_id, (path, mtime) in newest.items()]
    rooms.sort(key=lambda entry: entry[2], reverse=True)
    return rooms


def list_available_rooms():
    """Return a Markdown list of known chat rooms with their last activity time."""
    history_files = get_all_history_files()
    if not history_files:
        return "No chat rooms available yet. Create one by joining a room!"

    lines = ["### Available Chat Rooms\n\n"]
    for room_id, _file_path, mod_time in history_files:
        last_activity = datetime.fromtimestamp(mod_time).strftime("%Y-%m-%d %H:%M:%S")
        lines.append(f"- **{room_id}**: Last activity {last_activity}\n")
    return "".join(lines)


def create_gradio_interface():
    """Build the Gradio Blocks UI for this node and wire its event handlers.

    Returns:
        gr.Blocks: the interface (not launched).
    """
    joined_prefix = "Joined room:"

    with gr.Blocks(title=f"Chat Node: {NODE_NAME}") as interface:
        gr.Markdown(f"# Chat Node: {NODE_NAME}")
        gr.Markdown("Join a room by entering a room ID below or create a new one.")

        # Room list and management
        with gr.Row():
            with gr.Column(scale=3):
                room_list = gr.Markdown(value="Loading available rooms...")
                refresh_button = gr.Button("🔄 Refresh Room List")
            with gr.Column(scale=1):
                clear_button = gr.Button("🧹 Clear All Chat History", variant="stop")

        # Join room controls with 2D grid input
        with gr.Row():
            with gr.Column(scale=2):
                room_id_input = gr.Textbox(label="Room ID", placeholder="Enter room ID or use x,y coordinates")
                join_button = gr.Button("Join Room")
            with gr.Column(scale=1):
                with gr.Row():
                    x_coord = gr.Number(label="X", value=0, minimum=0, maximum=GRID_WIDTH - 1, step=1)
                    y_coord = gr.Number(label="Y", value=0, minimum=0, maximum=GRID_HEIGHT - 1, step=1)
                grid_join_button = gr.Button("Join by Coordinates")

        # Chat area with multiline support
        chat_history_output = gr.Textbox(label="Chat History", lines=20, max_lines=20)

        # Message controls with multiline support
        with gr.Row():
            username_input = gr.Textbox(label="Username", placeholder="Enter your username", value="User")
            with gr.Column(scale=3):
                message_input = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here. Press Shift+Enter for new line, Enter to send.",
                    lines=3,
                )
            with gr.Column(scale=1):
                send_button = gr.Button("Send")
                map_button = gr.Button("🗺️ Show Map")

        # Current room display
        current_room_display = gr.Textbox(label="Current Room", value="Not joined any room yet")

        def join_by_coordinates(x, y):
            """Turn grid coordinates into an "x,y" room id (blank inputs count as 0)."""
            return f"{int(x or 0)},{int(y or 0)}"

        def send_and_clear(message, username, room_id, history=""):
            """Queue each non-empty line of *message* and append it to the chat box.

            Returns:
                (message box value, chat history value). The draft is kept
                when nothing was sent, and the existing history is never
                discarded.
            """
            history = history or ""
            if not room_id.startswith(joined_prefix):
                # Keep the draft in the input box; show the hint in the chat area
                return message, "Please join a room first"
            actual_room_id = room_id[len(joined_prefix):].strip()

            sent_lines = []
            for line in message.strip().split("\n"):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                sent_msg = send_message(line, username, actual_room_id)
                if sent_msg:
                    sent_lines.append(sent_msg)

            if not sent_lines:
                return message, history
            new_text = "\n".join(sent_lines)
            if history.strip():
                return "", f"{history.rstrip()}\n{new_text}"
            return "", new_text

        def show_sector_map(room_id):
            """Return the sector map, or a hint when no room has been joined."""
            if not room_id.startswith(joined_prefix):
                return "Please join a room first to view the map"
            return generate_sector_map()

        # Event handlers
        refresh_button.click(list_available_rooms, inputs=[], outputs=[room_list])
        clear_button.click(send_clear_command, inputs=[], outputs=[room_list])

        # Link grid coordinates to room ID, then join it
        grid_join_button.click(
            join_by_coordinates,
            inputs=[x_coord, y_coord],
            outputs=[room_id_input],
        ).then(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output],
        )
        join_button.click(
            join_room,
            inputs=[room_id_input, chat_history_output],
            outputs=[current_room_display, chat_history_output],
        )

        send_inputs = [message_input, username_input, current_room_display, chat_history_output]
        send_outputs = [message_input, chat_history_output]
        send_button.click(send_and_clear, inputs=send_inputs, outputs=send_outputs)
        # Enter submits; Shift+Enter inserts a new line
        message_input.submit(send_and_clear, inputs=send_inputs, outputs=send_outputs)

        map_button.click(show_sector_map, inputs=[current_room_display], outputs=[chat_history_output])

        # On load, populate room list
        interface.load(list_available_rooms, inputs=[], outputs=[room_list])

    return interface


async def _dispatch_queued_message(msg_data):
    """Handle one message queued by the Gradio UI.

    Commands (``clear_history``) are executed. Chat messages are stamped
    with time and node, cached (capped at 500 per room), persisted and
    broadcast to their room.
    """
    if msg_data.get("type") == "command":
        if msg_data.get("command") == "clear_history":
            await clear_all_history()
        else:
            logger.warning(f"Ignoring unknown queued command: {msg_data.get('command')!r}")
        return

    room_id = msg_data.get("room_id")
    if room_id is None:
        logger.warning(f"Dropping queued message without room_id: {msg_data!r}")
        return

    msg_data["timestamp"] = datetime.now().isoformat()
    msg_data["sender_node"] = NODE_NAME

    room_messages = chat_history.setdefault(room_id, [])
    room_messages.append(msg_data)
    if len(room_messages) > 500:  # Same in-memory cap as WebSocket messages
        chat_history[room_id] = room_messages[-500:]

    save_message_to_history(room_id, msg_data)
    await broadcast_message(msg_data, room_id)


async def process_message_queue():
    """Forever drain ``message_queue`` (filled by Gradio callbacks) in FIFO order.

    A failure on one message is logged and does not stop the processor.
    """
    while True:
        while message_queue:
            msg_data = message_queue.pop(0)
            try:
                await _dispatch_queued_message(msg_data)
            except Exception:
                logger.exception(f"Failed to process queued message: {msg_data!r}")
        # Sleep to avoid busy-waiting
        await asyncio.sleep(0.1)


async def main():
    """Start the WebSocket server, the UI queue processor and the Gradio app.

    Runs until the uvicorn server stops, then cancels the queue processor and
    closes the WebSocket server.
    """
    global NODE_NAME, main_event_loop
    NODE_NAME, port = get_node_name()

    # Store the main event loop for later use
    main_event_loop = asyncio.get_running_loop()

    ws_server = await start_websocket_server()

    # Keep a reference: the event loop holds tasks only weakly, so an
    # unreferenced task may be garbage-collected mid-execution.
    logger.info("Starting message queue processor")
    queue_task = asyncio.create_task(process_message_queue())

    interface = create_gradio_interface()

    # Custom middleware to extract node name from URL query parameters
    from starlette.middleware.base import BaseHTTPMiddleware

    class NodeNameMiddleware(BaseHTTPMiddleware):
        """Rename this node from a ``node_name`` query parameter on any request."""

        async def dispatch(self, request, call_next):
            global NODE_NAME
            # NOTE(review): unauthenticated - any HTTP client can rename the node; confirm intended.
            query_params = dict(request.query_params)
            if "node_name" in query_params:
                NODE_NAME = query_params["node_name"]
                logger.info(f"Node name set to {NODE_NAME} from URL parameter")
            return await call_next(request)

    app = gr.routes.App.create_app(interface)
    app.add_middleware(NodeNameMiddleware)
    # Kept from the original wiring; presumably also registers Gradio's
    # startup hooks for `interface` - TODO confirm before removing.
    gr.routes.mount_gradio_app(app, interface, path="/")

    import uvicorn
    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    http_server = uvicorn.Server(config)
    logger.info(f"Starting Gradio interface on http://0.0.0.0:{port} with node name '{NODE_NAME}'")

    try:
        await http_server.serve()
    finally:
        queue_task.cancel()
        try:
            await queue_task
        except asyncio.CancelledError:
            pass
        ws_server.close()
        await ws_server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())