import asyncio
import logging
from typing import Dict, Optional, Set
from aiohttp import web, WSMsgType
import json
import time
import datetime
from api_core import VideoGenerationAPI

logger = logging.getLogger(__name__)


class UserSession:
    """
    Represents a user's session with the API.

    Each WebSocket connection gets its own session with separate queues (one
    background worker per queue) and per-session request counters. The actual
    work is delegated to ``shared_api``, which is shared by all sessions.
    """

    def __init__(self, user_id: str, user_role: str, ws: web.WebSocketResponse, shared_api):
        """
        Args:
            user_id: Identifier of the connected user.
            user_role: Role string; 'anon' and 'normal' get lower video
                concurrency caps (see _process_video_queue).
            ws: WebSocket every response of this session is sent on.
            shared_api: Object shared by all sessions that performs the real
                work (chat, video, search, simulation, caption, thumbnail).
        """
        self.user_id = user_id
        self.user_role = user_role
        self.ws = ws
        self.shared_api = shared_api  # For shared resources like endpoint manager

        # Separate queues for this user session, each drained by its own worker
        self.chat_queue = asyncio.Queue()
        self.video_queue = asyncio.Queue()
        self.search_queue = asyncio.Queue()
        self.simulation_queue = asyncio.Queue()  # Description evolution requests

        # Per-category counters, incremented by the workers after a response is sent
        self.request_counts = {
            'chat': 0,
            'video': 0,
            'search': 0,
            'simulation': 0,
        }

        # Per-category time of the last counted request (initialised to creation
        # time). Only recorded here; no rate limit is enforced by this class.
        now = time.time()
        self.last_request_times = {
            'chat': now,
            'video': now,
            'search': now,
            'simulation': now,
        }

        # Session creation time
        self.created_at = now
        self.background_tasks = []

    async def start(self):
        """Start one background worker task per queue for this session."""
        self.background_tasks = [
            asyncio.create_task(self._process_chat_queue()),
            asyncio.create_task(self._process_video_queue()),
            asyncio.create_task(self._process_search_queue()),
            asyncio.create_task(self._process_simulation_queue()),
        ]
        logger.info(f"Started session for user {self.user_id} with role {self.user_role}")

    async def stop(self):
        """Cancel all background workers and wait until they have finished."""
        for task in self.background_tasks:
            task.cancel()
        try:
            # Worker exceptions/cancellations are collected, not re-raised
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass
        logger.info(f"Stopped session for user {self.user_id}")

    async def _process_chat_queue(self):
        """High priority worker for chat operations (join/message/leave)."""
        while True:
            # get() stays outside the try: a cancellation while waiting must not
            # reach the finally (no item was taken, so no task_done()).
            data = await self.chat_queue.get()
            action = None
            try:
                action = data.get('action')
                if action == 'join_chat':
                    result = await self.shared_api.handle_join_chat(data, self.ws)
                elif action == 'chat_message':
                    result = await self.shared_api.handle_chat_message(data, self.ws)
                elif action == 'leave_chat':
                    result = await self.shared_api.handle_leave_chat(data, self.ws)
                elif action == 'generate_video_thumbnail':
                    # process_generic_request sends its own response (success or
                    # error). task_done() is left to the finally clause, which
                    # still runs on `continue`; calling it here as well would
                    # raise ValueError and kill this worker.
                    await self.process_generic_request(data)
                    continue
                else:
                    raise ValueError(f"Unknown chat action: {action}")

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['chat'] += 1
                self.last_request_times['chat'] = time.time()

            except Exception as e:
                logger.error(f"Error processing chat request for user {self.user_id}: {e}")
                try:
                    await self.ws.send_json({
                        'action': action,
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Chat error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.chat_queue.task_done()

    async def _process_video_queue(self):
        """Process multiple video generation requests in parallel for this user.

        Concurrency is capped per role: 'anon' at 2, 'normal' at 4, every other
        role at the number of configured endpoints (minimum 1). When the worker
        is cancelled, in-flight generations are cancelled too.
        """
        from api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS

        # Set a per-user concurrent limit based on role
        max_concurrent = len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)
        if self.user_role == 'anon':
            max_concurrent = min(2, max_concurrent)  # Limit anonymous users
        elif self.user_role == 'normal':
            max_concurrent = min(4, max_concurrent)  # Standard users
        # Pro and admin can use all endpoints. Never drop to 0, which would stall
        # the queue forever instead of letting the request fail with an error.
        max_concurrent = max(1, max_concurrent)

        slots = asyncio.Semaphore(max_concurrent)
        active_tasks: Set[asyncio.Task] = set()

        async def process_single_request(data):
            """Generate one video and send the result (or an error) on the socket."""
            try:
                title = data.get('title', '')
                description = data.get('description', '')
                video_prompt_prefix = data.get('video_prompt_prefix', '')
                options = data.get('options', {})

                # Pass the user role to generate_video
                video_data = await self.shared_api.generate_video(
                    title,
                    description,
                    video_prompt_prefix,
                    options,
                    self.user_role
                )

                result = {
                    'action': 'generate_video',
                    'requestId': data.get('requestId'),
                    'success': True,
                    'video': video_data,
                }

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['video'] += 1
                self.last_request_times['video'] = time.time()

            except Exception as e:
                logger.error(f"Error processing video request for user {self.user_id}: {e}")
                try:
                    await self.ws.send_json({
                        'action': 'generate_video',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Video generation error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")

        def on_task_done(task: asyncio.Task) -> None:
            # Runs for completed, failed and cancelled tasks alike: free the
            # concurrency slot and balance the queue's get() with task_done().
            active_tasks.discard(task)
            slots.release()
            self.video_queue.task_done()
            if not task.cancelled() and task.exception() is not None:
                logger.error(f"Task failed with error for user {self.user_id}: {task.exception()}")

        try:
            while True:
                # Wait for a free slot before taking a request so that excess
                # requests stay queued instead of being polled for.
                await slots.acquire()
                data = await self.video_queue.get()
                task = asyncio.create_task(process_single_request(data))
                active_tasks.add(task)
                task.add_done_callback(on_task_done)
        finally:
            # Session is stopping: cancel in-flight generations rather than
            # leaving them running against a socket that is being closed.
            pending = list(active_tasks)
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _process_search_queue(self):
        """Medium priority worker for search operations."""
        while True:
            # get() stays outside the try: a cancellation while waiting must not
            # reach the finally (no item was taken, so no task_done()).
            data = await self.search_queue.get()
            try:
                request_id = data.get('requestId')
                # `or ''` also covers an explicit null query
                query = (data.get('query') or '').strip()
                attempt_count = data.get('attemptCount', 0)

                if not query:
                    logger.warning(f"Empty query received in request from user {self.user_id}: {data}")
                    result = {
                        'action': 'search',
                        'requestId': request_id,
                        'success': False,
                        'error': 'No search query provided'
                    }
                else:
                    try:
                        search_result = await self.shared_api.search_video(
                            query,
                            attempt_count=attempt_count
                        )

                        if search_result:
                            result = {
                                'action': 'search',
                                'requestId': request_id,
                                'success': True,
                                'result': search_result
                            }
                        else:
                            result = {
                                'action': 'search',
                                'requestId': request_id,
                                'success': False,
                                'error': 'No results found'
                            }
                    except Exception as e:
                        logger.error(f"Search error for user {self.user_id}, (attempt {attempt_count}): {str(e)}")
                        result = {
                            'action': 'search',
                            'requestId': request_id,
                            'success': False,
                            'error': f'Search error: {str(e)}'
                        }

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['search'] += 1
                self.last_request_times['search'] = time.time()

            except Exception as e:
                logger.error(f"Error in search queue processor for user {self.user_id}: {str(e)}")
                try:
                    error_response = {
                        'action': 'search',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Internal server error: {str(e)}'
                    }
                    await self.ws.send_json(error_response)
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.search_queue.task_done()

    async def _process_simulation_queue(self):
        """Dedicated worker for video simulation (description evolution) requests."""
        while True:
            # get() stays outside the try: a cancellation while waiting must not
            # reach the finally (no item was taken, so no task_done()).
            data = await self.simulation_queue.get()
            try:
                request_id = data.get('requestId')

                # Extract parameters from the request
                video_id = data.get('video_id', '')
                original_title = data.get('original_title', '')
                original_description = data.get('original_description', '')
                current_description = data.get('current_description', '')
                condensed_history = data.get('condensed_history', '')
                evolution_count = data.get('evolution_count', 0)
                chat_messages = data.get('chat_messages', '')

                logger.info(f"Processing video simulation for user {self.user_id}, video_id={video_id}, evolution_count={evolution_count}")

                # Validate required parameters
                if not original_title or not original_description or not current_description:
                    result = {
                        'action': 'simulate',
                        'requestId': request_id,
                        'success': False,
                        'error': 'Missing required parameters'
                    }
                else:
                    try:
                        # Call the simulate method in the API
                        simulation_result = await self.shared_api.simulate(
                            original_title=original_title,
                            original_description=original_description,
                            current_description=current_description,
                            condensed_history=condensed_history,
                            evolution_count=evolution_count,
                            chat_messages=chat_messages
                        )

                        result = {
                            'action': 'simulate',
                            'requestId': request_id,
                            'success': True,
                            'evolved_description': simulation_result['evolved_description'],
                            'condensed_history': simulation_result['condensed_history']
                        }
                    except Exception as e:
                        logger.error(f"Error simulating video for user {self.user_id}, video_id={video_id}: {str(e)}")
                        result = {
                            'action': 'simulate',
                            'requestId': request_id,
                            'success': False,
                            'error': f'Simulation error: {str(e)}'
                        }

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['simulation'] += 1
                self.last_request_times['simulation'] = time.time()

            except Exception as e:
                logger.error(f"Error in simulation queue processor for user {self.user_id}: {str(e)}")
                try:
                    error_response = {
                        'action': 'simulate',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Internal server error: {str(e)}'
                    }
                    await self.ws.send_json(error_response)
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.simulation_queue.task_done()

    async def process_generic_request(self, data: dict) -> None:
        """Handle general requests that don't fit into specialized queues.

        Exactly one JSON response is sent on the socket for every request;
        errors are reported as ``{'success': False, 'error': ...}``.
        """
        try:
            request_id = data.get('requestId')
            action = data.get('action')
            # `or {}` also covers an explicit null params value
            params = data.get('params') or {}

            def error_response(message: str):
                return {
                    'action': action,
                    'requestId': request_id,
                    'success': False,
                    'error': message
                }

            if action == 'heartbeat':
                # Include user role info in heartbeat response
                await self.ws.send_json({
                    'action': 'heartbeat',
                    'requestId': request_id,
                    'success': True,
                    'user_role': self.user_role
                })

            elif action == 'get_user_role':
                # Return the user role information
                await self.ws.send_json({
                    'action': 'get_user_role',
                    'requestId': request_id,
                    'success': True,
                    'user_role': self.user_role
                })

            elif action == 'generate_caption':
                title = params.get('title')
                description = params.get('description')

                if not title or not description:
                    await self.ws.send_json(error_response('Missing title or description'))
                    return

                caption = await self.shared_api.generate_caption(title, description)
                await self.ws.send_json({
                    'action': action,
                    'requestId': request_id,
                    'success': True,
                    'caption': caption
                })

            # evolve_description is handled by the dedicated simulation queue processor

            elif action == 'generate_video_thumbnail':
                title = data.get('title', '') or params.get('title', '')
                description = data.get('description', '') or params.get('description', '')
                video_prompt_prefix = data.get('video_prompt_prefix', '') or params.get('video_prompt_prefix', '')
                # Copy so the thumbnail defaults below do not mutate the caller's dict
                options = dict(data.get('options') or params.get('options') or {})

                if not title:
                    await self.ws.send_json(error_response('Missing title for thumbnail generation'))
                    return

                # Ensure the options include the thumbnail flag
                options['thumbnail'] = True

                # Prioritize thumbnail generation with higher priority
                options['priority'] = 'high'

                # Add small size settings if not already specified
                options.setdefault('width', 512)  # Default thumbnail width
                options.setdefault('height', 288)  # Default 16:9 aspect ratio
                options.setdefault('num_frames', 25)  # 1 second @ 25fps

                # Let the API know this is a thumbnail for a specific video
                options['video_id'] = data.get('video_id', f"thumbnail-{request_id}")

                logger.info(f"Generating thumbnail for video {options['video_id']} for user {self.user_id}")

                try:
                    # Generate the thumbnail
                    thumbnail_data = await self.shared_api.generate_video_thumbnail(
                        title, description, video_prompt_prefix, options, self.user_role
                    )

                    # Respond with the key name the request used
                    if 'thumbnailUrl' in data or 'thumbnailUrl' in params:
                        # Legacy format using thumbnailUrl
                        await self.ws.send_json({
                            'action': action,
                            'requestId': request_id,
                            'success': True,
                            'thumbnailUrl': thumbnail_data or "",
                        })
                    else:
                        # New format using thumbnail
                        await self.ws.send_json({
                            'action': action,
                            'requestId': request_id,
                            'success': True,
                            'thumbnail': thumbnail_data,
                        })
                except Exception as e:
                    logger.error(f"Error generating thumbnail: {str(e)}")
                    await self.ws.send_json(error_response(f"Thumbnail generation failed: {str(e)}"))

            # Handle deprecated thumbnail actions
            elif action == 'generate_thumbnail' or action == 'old_generate_thumbnail':
                # Redirect to video thumbnail generation
                logger.warning(f"Deprecated thumbnail action '{action}' used, redirecting to generate_video_thumbnail")

                # Extract parameters
                title = data.get('title', '') or params.get('title', '')
                description = data.get('description', '') or params.get('description', '')

                if not title or not description:
                    await self.ws.send_json(error_response('Missing title or description'))
                    return

                # Create a new request with the correct action
                new_request = {
                    'action': 'generate_video_thumbnail',
                    'requestId': request_id,
                    'title': title,
                    'description': description,
                    'options': {
                        'width': 512,
                        'height': 288,
                        'thumbnail': True,
                        'video_id': f"thumbnail-{request_id}"
                    }
                }

                # Process with the new action
                await self.process_generic_request(new_request)

            else:
                await self.ws.send_json(error_response(f'Unknown action: {action}'))

        except Exception as e:
            logger.error(f"Error processing generic request for user {self.user_id}: {str(e)}")
            try:
                await self.ws.send_json({
                    'action': data.get('action'),
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Internal server error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")


class SessionManager:
    """
    Manages all active user sessions and shared resources.

    Sessions are keyed by user_id; all of them share one VideoGenerationAPI.
    """

    def __init__(self):
        self.sessions: Dict[str, UserSession] = {}
        self.shared_api = VideoGenerationAPI()  # Single instance for shared resources
        self.session_lock = asyncio.Lock()

    async def create_session(self, user_id: str, user_role: str, ws: web.WebSocketResponse) -> UserSession:
        """Create, start and register a new session for ``user_id``.

        An existing session registered under the same ``user_id`` is replaced
        in the registry but not stopped.
        """
        async with self.session_lock:
            if user_id in self.sessions:
                # NOTE(review): the replaced session's workers keep running until
                # something calls stop() on it; confirm callers handle this.
                logger.warning(f"Replacing existing session for user {user_id}")
            session = UserSession(user_id, user_role, ws, self.shared_api)
            await session.start()
            self.sessions[user_id] = session
            return session

    async def delete_session(self, user_id: str) -> None:
        """Stop and unregister the session for ``user_id``, if any."""
        async with self.session_lock:
            session = self.sessions.pop(user_id, None)
            if session is not None:
                await session.stop()
                logger.info(f"Deleted session for user {user_id}")

    def get_session(self, user_id: str) -> Optional[UserSession]:
        """Return the session for ``user_id``, or None if there is none."""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """Stop every active session and clear the registry (used during shutdown)."""
        async with self.session_lock:
            for session in list(self.sessions.values()):
                await session.stop()
            self.sessions.clear()
            logger.info("Closed all active sessions")

    @property
    def session_count(self) -> int:
        """Number of active sessions."""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """Return session counts per role and summed request counters per category.

        The four known roles are always present in 'by_role'; any other role
        found on a session is added rather than raising KeyError.
        """
        stats = {
            'total_sessions': len(self.sessions),
            'by_role': {
                'anon': 0,
                'normal': 0,
                'pro': 0,
                'admin': 0
            },
            'requests': {
                'chat': 0,
                'video': 0,
                'search': 0,
                'simulation': 0
            }
        }

        by_role = stats['by_role']
        requests = stats['requests']
        for session in self.sessions.values():
            by_role[session.user_role] = by_role.get(session.user_role, 0) + 1
            for category in requests:
                requests[category] += session.request_counts.get(category, 0)

        return stats