Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 25,504 Bytes
2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 4b590f9 2e813e6 4b590f9 2e813e6 4b590f9 2e813e6 4b590f9 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 d7edecf 2e813e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 |
import asyncio
import logging
from typing import Dict, Set
from aiohttp import web, WSMsgType
import json
import time
import datetime
from api_core import VideoGenerationAPI
# Module-level logger, named after this module so handlers/levels can be configured per module.
logger = logging.getLogger(name=__name__)
class UserSession:
    """
    Represents a user's session with the API.

    Each WebSocket connection gets its own session with separate request queues
    and per-type request counters. ``start()`` launches one background worker per
    queue; each worker takes request dicts off its queue, calls the shared API and
    replies on ``ws``. This class only records counts and timestamps; it does not
    enforce any rate limit itself.
    """

    def __init__(self, user_id: str, user_role: str, ws: "web.WebSocketResponse", shared_api):
        """
        Args:
            user_id: Identifier of the session owner, used in log messages.
            user_role: Role string. The video worker caps concurrency for
                ``'anon'`` and ``'normal'``; it is also echoed in heartbeats.
            ws: WebSocket on which every reply for this session is sent.
            shared_api: Shared API object used for all chat, video, search,
                simulation, caption and thumbnail calls.
        """
        self.user_id = user_id
        self.user_role = user_role
        self.ws = ws
        self.shared_api = shared_api  # For shared resources like endpoint manager

        # Separate queues for this user session.
        self.chat_queue = asyncio.Queue()
        self.video_queue = asyncio.Queue()
        self.search_queue = asyncio.Queue()
        self.simulation_queue = asyncio.Queue()  # Description evolution requests

        # Request counters per type, incremented by the workers after replying.
        self.request_counts = {
            'chat': 0,
            'video': 0,
            'search': 0,
            'simulation': 0,
        }

        now = time.time()
        # Time of the last recorded request per type (initialised to creation time).
        self.last_request_times = {kind: now for kind in self.request_counts}
        # Session creation time
        self.created_at = now
        self.background_tasks = []

    async def start(self):
        """Launch one background worker task per queue."""
        workers = (
            self._process_chat_queue,
            self._process_video_queue,
            self._process_search_queue,
            self._process_simulation_queue,
        )
        self.background_tasks = [asyncio.create_task(worker()) for worker in workers]
        logger.info(f"Started session for user {self.user_id} with role {self.user_role}")

    async def stop(self):
        """Cancel all background workers and wait until they have finished."""
        for task in self.background_tasks:
            task.cancel()
        try:
            # return_exceptions=True: a worker's CancelledError/exception is collected, not raised.
            await asyncio.gather(*self.background_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            # NOTE(review): this also swallows a cancellation of the caller of stop().
            pass
        logger.info(f"Stopped session for user {self.user_id}")

    def _record_request(self, kind: str) -> None:
        """Increment the counter and refresh the last-request timestamp for ``kind``."""
        self.request_counts[kind] += 1
        self.last_request_times[kind] = time.time()

    async def _send_error(self, action, request_id, message: str) -> None:
        """Send a ``success: False`` reply; a failure to send is logged, not raised."""
        try:
            await self.ws.send_json({
                'action': action,
                'requestId': request_id,
                'success': False,
                'error': message,
            })
        except Exception as send_error:
            logger.error(f"Error sending error response: {send_error}")

    async def _process_chat_queue(self):
        """High priority worker for chat operations (join / message / leave).

        ``generate_video_thumbnail`` requests arriving on this queue are delegated
        to ``process_generic_request``, which sends its own reply.
        """
        while True:
            data = await self.chat_queue.get()
            action = request_id = None
            try:
                action = data.get('action')
                request_id = data.get('requestId')
                if action == 'generate_video_thumbnail':
                    # process_generic_request replies itself, so skip the reply and
                    # metrics below. The `finally` still runs on `continue` and marks
                    # the item done exactly once; calling task_done() here as well
                    # would raise ValueError and kill this worker.
                    await self.process_generic_request(data)
                    continue
                if action == 'join_chat':
                    result = await self.shared_api.handle_join_chat(data, self.ws)
                elif action == 'chat_message':
                    result = await self.shared_api.handle_chat_message(data, self.ws)
                elif action == 'leave_chat':
                    result = await self.shared_api.handle_leave_chat(data, self.ws)
                else:
                    raise ValueError(f"Unknown chat action: {action}")
                await self.ws.send_json(result)
                self._record_request('chat')
            except Exception as e:
                logger.error(f"Error processing chat request for user {self.user_id}: {e}")
                await self._send_error(action, request_id, f'Chat error: {str(e)}')
            finally:
                self.chat_queue.task_done()

    async def _process_video_queue(self):
        """Run this user's video generation requests concurrently.

        Concurrency is capped by role: at most 2 for ``'anon'``, at most 4 for
        ``'normal'``, and one per configured endpoint otherwise (never fewer than
        1). Requests beyond the cap stay in ``video_queue``. When this worker is
        cancelled, in-flight generations are cancelled and awaited too.
        """
        from api_config import VIDEO_ROUND_ROBIN_ENDPOINT_URLS

        max_concurrent = len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS)
        if self.user_role == 'anon':
            max_concurrent = min(2, max_concurrent)  # Limit anonymous users
        elif self.user_role == 'normal':
            max_concurrent = min(4, max_concurrent)  # Standard users
        # Pro and admin can use all endpoints. With zero endpoints nothing would
        # ever be dequeued and clients would wait forever; let one request through
        # so it fails with an error reply instead.
        max_concurrent = max(1, max_concurrent)

        slots = asyncio.Semaphore(max_concurrent)
        active_tasks: Set[asyncio.Task] = set()

        async def process_single_request(data):
            """Generate one video, reply on the socket, then free its slot."""
            request_id = None
            try:
                request_id = data.get('requestId')
                # Pass the user role to generate_video
                video_data = await self.shared_api.generate_video(
                    data.get('title', ''),
                    data.get('description', ''),
                    data.get('video_prompt_prefix', ''),
                    data.get('options', {}),
                    self.user_role,
                )
                await self.ws.send_json({
                    'action': 'generate_video',
                    'requestId': request_id,
                    'success': True,
                    'video': video_data,
                })
                self._record_request('video')
            except Exception as e:
                logger.error(f"Error processing video request for user {self.user_id}: {e}")
                await self._send_error('generate_video', request_id, f'Video generation error: {str(e)}')
            finally:
                self.video_queue.task_done()
                slots.release()

        try:
            while True:
                # Take a slot *before* dequeuing so excess requests wait in the queue
                # (no polling: both awaits block until something happens).
                await slots.acquire()
                data = await self.video_queue.get()
                task = asyncio.create_task(process_single_request(data))
                # Hold a strong reference until the task finishes.
                active_tasks.add(task)
                task.add_done_callback(active_tasks.discard)
        finally:
            # Don't leave generations running for a session that is shutting down.
            pending = list(active_tasks)
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    async def _process_search_queue(self):
        """Medium priority worker for search operations.

        Each dequeued request gets one reply: the search result, or a
        ``success: False`` reply for an empty query, no results, or an error.
        """
        while True:
            # Dequeue outside the try: if the worker is cancelled while waiting here,
            # the finally must not call task_done() for an item never taken (that
            # raises ValueError and masks the cancellation).
            data = await self.search_queue.get()
            request_id = None
            try:
                request_id = data.get('requestId')
                # A missing or null query is treated as empty.
                query = (data.get('query') or '').strip()
                attempt_count = data.get('attemptCount', 0)
                reply = {'action': 'search', 'requestId': request_id}

                if not query:
                    logger.warning(f"Empty query received in request from user {self.user_id}: {data}")
                    reply.update(success=False, error='No search query provided')
                else:
                    try:
                        search_result = await self.shared_api.search_video(
                            query,
                            attempt_count=attempt_count
                        )
                    except Exception as e:
                        logger.error(f"Search error for user {self.user_id}, (attempt {attempt_count}): {str(e)}")
                        reply.update(success=False, error=f'Search error: {str(e)}')
                    else:
                        if search_result:
                            reply.update(success=True, result=search_result)
                        else:
                            reply.update(success=False, error='No results found')

                await self.ws.send_json(reply)
                self._record_request('search')
            except Exception as e:
                logger.error(f"Error in search queue processor for user {self.user_id}: {str(e)}")
                await self._send_error('search', request_id, f'Internal server error: {str(e)}')
            finally:
                self.search_queue.task_done()

    async def _process_simulation_queue(self):
        """Dedicated worker for video simulation (description evolution) requests."""
        while True:
            # Dequeue outside the try so a cancellation while waiting never reaches
            # the finally's task_done() without a matching item.
            data = await self.simulation_queue.get()
            request_id = None
            try:
                request_id = data.get('requestId')
                video_id = data.get('video_id', '')
                original_title = data.get('original_title', '')
                original_description = data.get('original_description', '')
                current_description = data.get('current_description', '')
                evolution_count = data.get('evolution_count', 0)

                logger.info(f"Processing video simulation for user {self.user_id}, video_id={video_id}, evolution_count={evolution_count}")

                reply = {'action': 'simulate', 'requestId': request_id}
                if not (original_title and original_description and current_description):
                    reply.update(success=False, error='Missing required parameters')
                else:
                    try:
                        simulation_result = await self.shared_api.simulate(
                            original_title=original_title,
                            original_description=original_description,
                            current_description=current_description,
                            condensed_history=data.get('condensed_history', ''),
                            evolution_count=evolution_count,
                            chat_messages=data.get('chat_messages', ''),
                        )
                        # Keyword args are evaluated before update(), so a missing key
                        # leaves `reply` untouched and falls through to the except.
                        reply.update(
                            success=True,
                            evolved_description=simulation_result['evolved_description'],
                            condensed_history=simulation_result['condensed_history'],
                        )
                    except Exception as e:
                        logger.error(f"Error simulating video for user {self.user_id}, video_id={video_id}: {str(e)}")
                        reply.update(success=False, error=f'Simulation error: {str(e)}')

                await self.ws.send_json(reply)
                self._record_request('simulation')
            except Exception as e:
                logger.error(f"Error in simulation queue processor for user {self.user_id}: {str(e)}")
                await self._send_error('simulate', request_id, f'Internal server error: {str(e)}')
            finally:
                self.simulation_queue.task_done()

    async def _handle_video_thumbnail(self, data: dict, params: dict, request_id) -> None:
        """Generate a small, high-priority thumbnail clip and reply with it.

        Fields are read from the top level of ``data`` first, then from ``params``.
        The reply key is ``thumbnailUrl`` when the request mentions that key
        (legacy clients), otherwise ``thumbnail``.
        """
        action = 'generate_video_thumbnail'
        title = data.get('title', '') or params.get('title', '')
        description = data.get('description', '') or params.get('description', '')
        video_prompt_prefix = data.get('video_prompt_prefix', '') or params.get('video_prompt_prefix', '')
        # Copy so the defaults below don't mutate the caller's request dict.
        options = dict(data.get('options') or params.get('options') or {})

        if not title:
            await self._send_error(action, request_id, 'Missing title for thumbnail generation')
            return

        options['thumbnail'] = True
        options['priority'] = 'high'
        options.setdefault('width', 512)      # Default thumbnail width
        options.setdefault('height', 288)     # 512x288 is 16:9
        options.setdefault('num_frames', 25)  # presumably ~1 second at 25 fps
        # Let the API know this is a thumbnail for a specific video
        options['video_id'] = data.get('video_id', f"thumbnail-{request_id}")

        logger.info(f"Generating thumbnail for video {options['video_id']} for user {self.user_id}")
        try:
            thumbnail_data = await self.shared_api.generate_video_thumbnail(
                title, description, video_prompt_prefix, options, self.user_role
            )
        except Exception as e:
            logger.error(f"Error generating thumbnail: {str(e)}")
            await self._send_error(action, request_id, f"Thumbnail generation failed: {str(e)}")
            return

        if 'thumbnailUrl' in data or 'thumbnailUrl' in params:
            payload = {'thumbnailUrl': thumbnail_data or ""}  # Legacy format
        else:
            payload = {'thumbnail': thumbnail_data}
        await self.ws.send_json({
            'action': action,
            'requestId': request_id,
            'success': True,
            **payload,
        })

    async def process_generic_request(self, data: dict) -> None:
        """Handle requests that don't fit into the specialised queues.

        Supported actions: ``heartbeat``, ``get_user_role``, ``generate_caption``,
        ``generate_video_thumbnail``, and the deprecated ``generate_thumbnail`` /
        ``old_generate_thumbnail`` (rewritten into ``generate_video_thumbnail``).
        Replies are sent on ``self.ws``; nothing is returned. Unexpected errors
        are logged and answered with an ``Internal server error`` reply.
        """
        action = request_id = None
        try:
            request_id = data.get('requestId')
            action = data.get('action')
            # `params` may be absent or null; treat both as empty.
            params = data.get('params') or {}

            if action in ('heartbeat', 'get_user_role'):
                # Both report the session's role.
                await self.ws.send_json({
                    'action': action,
                    'requestId': request_id,
                    'success': True,
                    'user_role': self.user_role,
                })
            elif action == 'generate_caption':
                title = params.get('title')
                description = params.get('description')
                if not title or not description:
                    await self._send_error(action, request_id, 'Missing title or description')
                    return
                caption = await self.shared_api.generate_caption(title, description)
                await self.ws.send_json({
                    'action': action,
                    'requestId': request_id,
                    'success': True,
                    'caption': caption,
                })
            # evolve_description is handled by the dedicated simulation queue worker
            elif action == 'generate_video_thumbnail':
                await self._handle_video_thumbnail(data, params, request_id)
            elif action in ('generate_thumbnail', 'old_generate_thumbnail'):
                logger.warning(f"Deprecated thumbnail action '{action}' used, redirecting to generate_video_thumbnail")
                title = data.get('title', '') or params.get('title', '')
                description = data.get('description', '') or params.get('description', '')
                if not title or not description:
                    await self._send_error(action, request_id, 'Missing title or description')
                    return
                # Re-dispatch as the current action with thumbnail-sized options.
                await self.process_generic_request({
                    'action': 'generate_video_thumbnail',
                    'requestId': request_id,
                    'title': title,
                    'description': description,
                    'options': {
                        'width': 512,
                        'height': 288,
                        'thumbnail': True,
                        'video_id': f"thumbnail-{request_id}",
                    },
                })
            else:
                await self._send_error(action, request_id, f'Unknown action: {action}')
        except Exception as e:
            logger.error(f"Error processing generic request for user {self.user_id}: {str(e)}")
            await self._send_error(action, request_id, f'Internal server error: {str(e)}')
class SessionManager:
    """
    Manages all active user sessions and the API instance they share.
    """

    def __init__(self):
        # user_id -> UserSession
        self.sessions = {}
        self.shared_api = VideoGenerationAPI()  # Single instance for shared resources
        # Guards mutations of self.sessions in the async methods below.
        self.session_lock = asyncio.Lock()

    async def create_session(self, user_id: str, user_role: str, ws: "web.WebSocketResponse") -> "UserSession":
        """Create, start and register a new session for ``user_id``.

        Returns:
            The started session, also stored in ``self.sessions[user_id]``.
        """
        async with self.session_lock:
            if user_id in self.sessions:
                # NOTE(review): the previous session is overwritten without being
                # stopped, so its background tasks keep running. Left as-is because
                # callers may still hold and use the old session; confirm intent.
                logger.warning(f"Replacing existing session for user {user_id}; the previous session is not stopped")
            session = UserSession(user_id, user_role, ws, self.shared_api)
            await session.start()
            self.sessions[user_id] = session
            return session

    async def delete_session(self, user_id: str) -> None:
        """Stop and unregister the session for ``user_id``; no-op if there is none."""
        async with self.session_lock:
            session = self.sessions.get(user_id)
            if session is None:
                return
            await session.stop()
            del self.sessions[user_id]
            logger.info(f"Deleted session for user {user_id}")

    def get_session(self, user_id: str) -> "UserSession":
        """Return the session for ``user_id``, or ``None`` if it does not exist."""
        return self.sessions.get(user_id)

    async def close_all_sessions(self) -> None:
        """Stop and unregister every active session (used during shutdown)."""
        async with self.session_lock:
            for session in list(self.sessions.values()):
                await session.stop()
            self.sessions.clear()
            logger.info("Closed all active sessions")

    @property
    def session_count(self) -> int:
        """Number of active sessions."""
        return len(self.sessions)

    def get_session_stats(self) -> Dict:
        """Summarise active sessions.

        Returns:
            ``{'total_sessions': int, 'by_role': {role: count}, 'requests': {kind: count}}``.
            ``by_role`` always contains ``anon``/``normal``/``pro``/``admin``; any
            other role gets its own entry instead of raising ``KeyError``.
        """
        by_role = {'anon': 0, 'normal': 0, 'pro': 0, 'admin': 0}
        requests = {'chat': 0, 'video': 0, 'search': 0, 'simulation': 0}
        for session in self.sessions.values():
            role = session.user_role
            by_role[role] = by_role.get(role, 0) + 1
            for kind in requests:
                requests[kind] += session.request_counts.get(kind, 0)
        return {
            'total_sessions': len(self.sessions),
            'by_role': by_role,
            'requests': requests,
        }