# Source: aitube2 / api_session.py
# Author: jbilcke-hf (HF Staff)
# Commit: 4b590f9 - "small prompt fix"
import asyncio
import datetime
import json
import logging
import time
from typing import Dict, Optional, Set

from aiohttp import web, WSMsgType

from api_core import VideoGenerationAPI
logger = logging.getLogger(__name__)
class UserSession:
    """
    Per-connection session state and queue workers.

    A session owns four asyncio queues (chat, video, search, simulation) and,
    once `start()` has been called, one background worker task per queue.
    Workers delegate the actual work to `shared_api` and write every response
    to the WebSocket given at construction time.

    The session also records a request counter and a last-request timestamp
    per category; nothing in this class enforces a limit based on them.
    """

    def __init__(self, user_id: str, user_role: str, ws: "web.WebSocketResponse", shared_api):
        """
        Args:
            user_id: Identifier used in log messages.
            user_role: Role string; 'anon' and 'normal' get a lower video
                concurrency cap, any other value gets the full cap.
            ws: WebSocket that receives every response via `send_json`.
            shared_api: Object providing the handlers the workers call
                (chat, video, search, simulate, caption, thumbnail).
        """
        self.user_id = user_id
        self.user_role = user_role
        self.ws = ws
        self.shared_api = shared_api  # For shared resources like endpoint manager

        # One queue per request category, private to this session.
        self.chat_queue = asyncio.Queue()
        self.video_queue = asyncio.Queue()
        self.search_queue = asyncio.Queue()
        self.simulation_queue = asyncio.Queue()

        # Number of successfully answered requests per category.
        self.request_counts = {
            'chat': 0,
            'video': 0,
            'search': 0,
            'simulation': 0,
        }

        # Timestamp of the most recent answered request per category,
        # initialised to the session creation time.
        now = time.time()
        self.last_request_times = {
            'chat': now,
            'video': now,
            'search': now,
            'simulation': now,
        }

        self.created_at = now
        self.background_tasks = []

    async def start(self):
        """Start one worker task per queue."""
        self.background_tasks = [
            asyncio.create_task(self._process_chat_queue()),
            asyncio.create_task(self._process_video_queue()),
            asyncio.create_task(self._process_search_queue()),
            asyncio.create_task(self._process_simulation_queue()),
        ]
        logger.info(f"Started session for user {self.user_id} with role {self.user_role}")

    async def stop(self):
        """Cancel every worker task and wait until they have finished."""
        for task in self.background_tasks:
            task.cancel()

        try:
            # return_exceptions=True so one worker's failure/cancellation
            # does not prevent waiting for the others.
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass

        logger.info(f"Stopped session for user {self.user_id}")

    async def _process_chat_queue(self):
        """
        Worker for chat operations.

        Handles 'join_chat', 'chat_message' and 'leave_chat' through
        `shared_api` and sends the handler's result to the client.
        'generate_video_thumbnail' items are forwarded to
        `process_generic_request`, which sends its own response.
        """
        while True:
            # Dequeue outside the try block: `task_done()` in `finally` must
            # only run for an item that was actually taken from the queue.
            data = await self.chat_queue.get()
            try:
                action = data.get('action')

                if action == 'generate_video_thumbnail':
                    # process_generic_request already replies to the client,
                    # so skip the normal response/metrics path. `finally`
                    # below marks the item done exactly once.
                    await self.process_generic_request(data)
                    continue

                if action == 'join_chat':
                    result = await self.shared_api.handle_join_chat(data, self.ws)
                elif action == 'chat_message':
                    result = await self.shared_api.handle_chat_message(data, self.ws)
                elif action == 'leave_chat':
                    result = await self.shared_api.handle_leave_chat(data, self.ws)
                else:
                    raise ValueError(f"Unknown chat action: {action}")

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['chat'] += 1
                self.last_request_times['chat'] = time.time()
            except Exception as e:
                logger.error(f"Error processing chat request for user {self.user_id}: {e}")
                try:
                    await self.ws.send_json({
                        # .get() so a request without 'action' still gets an
                        # error reply instead of failing again here.
                        'action': data.get('action'),
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Chat error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.chat_queue.task_done()

    async def _process_video_queue(self):
        """
        Worker that runs several video generation requests concurrently.

        The concurrency cap is the number of configured round-robin endpoints,
        further limited to 2 for 'anon' and 4 for 'normal' users. The queue is
        polled every 0.1 s. When this worker is cancelled, generation tasks
        that are still running are cancelled as well.
        """
        from api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS

        max_concurrent = len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)
        if self.user_role == 'anon':
            max_concurrent = min(2, max_concurrent)
        elif self.user_role == 'normal':
            max_concurrent = min(4, max_concurrent)
        # Any other role may use all endpoints.

        active_tasks = set()

        async def process_single_request(data):
            """Generate one video and send the success or error response."""
            try:
                title = data.get('title', '')
                description = data.get('description', '')
                video_prompt_prefix = data.get('video_prompt_prefix', '')
                options = data.get('options', {})

                video_data = await self.shared_api.generate_video(
                    title, description, video_prompt_prefix, options, self.user_role
                )

                await self.ws.send_json({
                    'action': 'generate_video',
                    'requestId': data.get('requestId'),
                    'success': True,
                    'video': video_data,
                })

                # Update metrics
                self.request_counts['video'] += 1
                self.last_request_times['video'] = time.time()
            except Exception as e:
                logger.error(f"Error processing video request for user {self.user_id}: {e}")
                try:
                    await self.ws.send_json({
                        'action': 'generate_video',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Video generation error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                active_tasks.discard(asyncio.current_task())
                # Keep the queue's unfinished-task count in step with the
                # other workers so `video_queue.join()` can complete.
                self.video_queue.task_done()

        try:
            while True:
                # Drop finished tasks. Inspect rather than await them so a
                # cancelled child cannot raise CancelledError in this worker.
                for finished in [task for task in active_tasks if task.done()]:
                    active_tasks.discard(finished)
                    if not finished.cancelled() and finished.exception() is not None:
                        logger.error(
                            f"Task failed with error for user {self.user_id}: {finished.exception()}"
                        )

                # Start new tasks while there is capacity and queued work.
                while len(active_tasks) < max_concurrent:
                    try:
                        data = self.video_queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    active_tasks.add(asyncio.create_task(process_single_request(data)))

                # Wait a short time before checking the queue again.
                await asyncio.sleep(0.1)
        finally:
            # The worker is going away (normally because the session is being
            # stopped): do not leave generation tasks running unattended.
            pending = [task for task in active_tasks if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _process_search_queue(self):
        """
        Worker for search operations.

        Each item is answered with an 'action': 'search' message: success with
        the search result, or failure for an empty query, an empty result, or
        an exception raised by `shared_api.search_video`.
        """
        while True:
            # Dequeue outside the try block so cancellation while idle does
            # not reach the `task_done()` in `finally`.
            data = await self.search_queue.get()
            try:
                request_id = data.get('requestId')
                # `or ''` also covers an explicit null query.
                query = (data.get('query') or '').strip()
                attempt_count = data.get('attemptCount', 0)

                if not query:
                    logger.warning(f"Empty query received in request from user {self.user_id}: {data}")
                    result = {
                        'action': 'search',
                        'requestId': request_id,
                        'success': False,
                        'error': 'No search query provided'
                    }
                else:
                    try:
                        search_result = await self.shared_api.search_video(
                            query,
                            attempt_count=attempt_count
                        )
                        if search_result:
                            result = {
                                'action': 'search',
                                'requestId': request_id,
                                'success': True,
                                'result': search_result
                            }
                        else:
                            result = {
                                'action': 'search',
                                'requestId': request_id,
                                'success': False,
                                'error': 'No results found'
                            }
                    except Exception as e:
                        logger.error(f"Search error for user {self.user_id}, (attempt {attempt_count}): {str(e)}")
                        result = {
                            'action': 'search',
                            'requestId': request_id,
                            'success': False,
                            'error': f'Search error: {str(e)}'
                        }

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['search'] += 1
                self.last_request_times['search'] = time.time()
            except Exception as e:
                logger.error(f"Error in search queue processor for user {self.user_id}: {str(e)}")
                try:
                    await self.ws.send_json({
                        'action': 'search',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Internal server error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.search_queue.task_done()

    async def _process_simulation_queue(self):
        """
        Worker for video simulation (description evolution) requests.

        Requires 'original_title', 'original_description' and
        'current_description' to be non-empty; otherwise an error response is
        sent. On success the reply carries 'evolved_description' and
        'condensed_history' taken from `shared_api.simulate`'s result.
        """
        while True:
            # Dequeue outside the try block so cancellation while idle does
            # not reach the `task_done()` in `finally`.
            data = await self.simulation_queue.get()
            try:
                request_id = data.get('requestId')

                video_id = data.get('video_id', '')
                original_title = data.get('original_title', '')
                original_description = data.get('original_description', '')
                current_description = data.get('current_description', '')
                condensed_history = data.get('condensed_history', '')
                evolution_count = data.get('evolution_count', 0)
                chat_messages = data.get('chat_messages', '')

                logger.info(f"Processing video simulation for user {self.user_id}, video_id={video_id}, evolution_count={evolution_count}")

                if not original_title or not original_description or not current_description:
                    result = {
                        'action': 'simulate',
                        'requestId': request_id,
                        'success': False,
                        'error': 'Missing required parameters'
                    }
                else:
                    try:
                        simulation_result = await self.shared_api.simulate(
                            original_title=original_title,
                            original_description=original_description,
                            current_description=current_description,
                            condensed_history=condensed_history,
                            evolution_count=evolution_count,
                            chat_messages=chat_messages
                        )
                        result = {
                            'action': 'simulate',
                            'requestId': request_id,
                            'success': True,
                            'evolved_description': simulation_result['evolved_description'],
                            'condensed_history': simulation_result['condensed_history']
                        }
                    except Exception as e:
                        logger.error(f"Error simulating video for user {self.user_id}, video_id={video_id}: {str(e)}")
                        result = {
                            'action': 'simulate',
                            'requestId': request_id,
                            'success': False,
                            'error': f'Simulation error: {str(e)}'
                        }

                await self.ws.send_json(result)

                # Update metrics
                self.request_counts['simulation'] += 1
                self.last_request_times['simulation'] = time.time()
            except Exception as e:
                logger.error(f"Error in simulation queue processor for user {self.user_id}: {str(e)}")
                try:
                    await self.ws.send_json({
                        'action': 'simulate',
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Internal server error: {str(e)}'
                    })
                except Exception as send_error:
                    logger.error(f"Error sending error response: {send_error}")
            finally:
                self.simulation_queue.task_done()

    async def process_generic_request(self, data: dict) -> None:
        """
        Handle requests that are not routed through a dedicated queue.

        Supported actions: 'heartbeat', 'get_user_role', 'generate_caption',
        'generate_video_thumbnail', and the deprecated 'generate_thumbnail' /
        'old_generate_thumbnail' (re-dispatched as 'generate_video_thumbnail').
        Any other action gets an 'Unknown action' error response.

        Request fields are read from the top level of `data` first and, for
        title/description/prefix/options, from `data['params']` as a fallback.
        Exactly one response is sent on the session WebSocket; exceptions are
        reported to the client as 'Internal server error' and not re-raised.
        """
        try:
            request_id = data.get('requestId')
            action = data.get('action')
            # Tolerate a missing or null 'params' object.
            params = data.get('params') or {}

            def error_response(message: str) -> dict:
                return {
                    'action': action,
                    'requestId': request_id,
                    'success': False,
                    'error': message
                }

            if action in ('heartbeat', 'get_user_role'):
                # Both actions answer with the caller's role.
                await self.ws.send_json({
                    'action': action,
                    'requestId': request_id,
                    'success': True,
                    'user_role': self.user_role
                })

            elif action == 'generate_caption':
                title = params.get('title')
                description = params.get('description')

                if not title or not description:
                    await self.ws.send_json(error_response('Missing title or description'))
                    return

                caption = await self.shared_api.generate_caption(title, description)
                await self.ws.send_json({
                    'action': action,
                    'requestId': request_id,
                    'success': True,
                    'caption': caption
                })

            # evolve_description is handled by the dedicated simulation queue processor

            elif action == 'generate_video_thumbnail':
                title = data.get('title', '') or params.get('title', '')
                description = data.get('description', '') or params.get('description', '')
                video_prompt_prefix = data.get('video_prompt_prefix', '') or params.get('video_prompt_prefix', '')

                if not title:
                    await self.ws.send_json(error_response('Missing title for thumbnail generation'))
                    return

                # Work on a copy so the caller's request dict is not modified.
                options = dict(data.get('options') or params.get('options') or {})

                options['thumbnail'] = True
                options['priority'] = 'high'

                # Small defaults unless the request specified its own size.
                options.setdefault('width', 512)
                options.setdefault('height', 288)  # 16:9 with width 512
                options.setdefault('num_frames', 25)  # 1 second @ 25fps

                # Tell the API which video this thumbnail belongs to.
                options['video_id'] = data.get('video_id', f"thumbnail-{request_id}")

                logger.info(f"Generating thumbnail for video {options['video_id']} for user {self.user_id}")

                try:
                    thumbnail_data = await self.shared_api.generate_video_thumbnail(
                        title, description, video_prompt_prefix, options, self.user_role
                    )

                    # The response key mirrors the naming used by the request.
                    if 'thumbnailUrl' in data or 'thumbnailUrl' in params:
                        # Legacy format using thumbnailUrl
                        await self.ws.send_json({
                            'action': action,
                            'requestId': request_id,
                            'success': True,
                            'thumbnailUrl': thumbnail_data or "",
                        })
                    else:
                        # New format using thumbnail
                        await self.ws.send_json({
                            'action': action,
                            'requestId': request_id,
                            'success': True,
                            'thumbnail': thumbnail_data,
                        })
                except Exception as e:
                    logger.error(f"Error generating thumbnail: {str(e)}")
                    await self.ws.send_json(error_response(f"Thumbnail generation failed: {str(e)}"))

            elif action in ('generate_thumbnail', 'old_generate_thumbnail'):
                logger.warning(f"Deprecated thumbnail action '{action}' used, redirecting to generate_video_thumbnail")

                title = data.get('title', '') or params.get('title', '')
                description = data.get('description', '') or params.get('description', '')

                if not title or not description:
                    await self.ws.send_json(error_response('Missing title or description'))
                    return

                # Re-dispatch as the current action; the reply therefore uses
                # 'generate_video_thumbnail' as its action.
                new_request = {
                    'action': 'generate_video_thumbnail',
                    'requestId': request_id,
                    'title': title,
                    'description': description,
                    'options': {
                        'width': 512,
                        'height': 288,
                        'thumbnail': True,
                        'video_id': f"thumbnail-{request_id}"
                    }
                }
                await self.process_generic_request(new_request)

            else:
                await self.ws.send_json(error_response(f'Unknown action: {action}'))

        except Exception as e:
            logger.error(f"Error processing generic request for user {self.user_id}: {str(e)}")
            try:
                await self.ws.send_json({
                    'action': data.get('action'),
                    'requestId': data.get('requestId'),
                    'success': False,
                    'error': f'Internal server error: {str(e)}'
                })
            except Exception as send_error:
                logger.error(f"Error sending error response: {send_error}")
class SessionManager:
    """
    Registry of active `UserSession` objects keyed by user id.

    Holds a single `VideoGenerationAPI` instance that is passed to every
    session, and an asyncio lock that serialises session creation, deletion
    and bulk shutdown.
    """

    def __init__(self):
        self.sessions = {}
        self.shared_api = VideoGenerationAPI()  # Single instance for shared resources
        self.session_lock = asyncio.Lock()

    async def create_session(self, user_id: str, user_role: str, ws: "web.WebSocketResponse") -> "UserSession":
        """
        Create, start and register a session for `user_id`.

        Returns:
            The started session.
        """
        async with self.session_lock:
            session = UserSession(user_id, user_role, ws, self.shared_api)
            await session.start()
            # NOTE(review): an existing entry for the same user_id is replaced
            # without being stopped, so its worker tasks keep running -
            # confirm whether callers guarantee unique ids per connection.
            self.sessions[user_id] = session
            return session

    async def delete_session(self, user_id: str) -> None:
        """Stop and unregister the session for `user_id`, if one exists."""
        async with self.session_lock:
            session = self.sessions.get(user_id)
            if session is None:
                return
            await session.stop()
            del self.sessions[user_id]
            logger.info(f"Deleted session for user {user_id}")

    def get_session(self, user_id: str) -> Optional["UserSession"]:
        """Return the session registered for `user_id`, or None."""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """Stop every registered session and empty the registry (used during shutdown)."""
        async with self.session_lock:
            # Iterate over a snapshot: stop() awaits, so the dict could be
            # observed by other coroutines while this loop is suspended.
            for session in list(self.sessions.values()):
                await session.stop()
            self.sessions.clear()
            logger.info("Closed all active sessions")

    @property
    def session_count(self) -> int:
        """Number of registered sessions."""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """
        Aggregate statistics over all registered sessions.

        Returns:
            A dict with 'total_sessions', a 'by_role' count (always containing
            'anon', 'normal', 'pro' and 'admin'; any other role encountered is
            added as its own key), and 'requests' totals per category summed
            from each session's `request_counts`.
        """
        by_role = {
            'anon': 0,
            'normal': 0,
            'pro': 0,
            'admin': 0
        }
        requests = {
            'chat': 0,
            'video': 0,
            'search': 0,
            'simulation': 0
        }

        for session in self.sessions.values():
            # .get() so a role outside the predefined set is counted instead
            # of raising KeyError and breaking the whole stats call.
            by_role[session.user_role] = by_role.get(session.user_role, 0) + 1
            for category in requests:
                requests[category] += session.request_counts[category]

        return {
            'total_sessions': len(self.sessions),
            'by_role': by_role,
            'requests': requests
        }