Spaces:
Building
on
CPU Upgrade
Building
on
CPU Upgrade
import asyncio | |
import json | |
import logging | |
import os | |
import pathlib | |
import time | |
import uuid | |
from aiohttp import web, WSMsgType | |
from typing import Dict, Any | |
from api_core import VideoGenerationAPI | |
from api_session import SessionManager | |
from api_metrics import MetricsTracker | |
from api_config import * | |
# Logging setup: INFO level, each record carries timestamp, logger name and severity.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Module-wide singletons shared by every request handler below.
session_manager = SessionManager()
metrics_tracker = MetricsTracker()

# Count of currently open anonymous websocket connections, keyed by client IP.
# Handlers take anon_connection_lock before reading or mutating the counts.
anon_connection_lock = asyncio.Lock()
anon_connections: Dict[str, int] = {}
async def status_handler(request: web.Request) -> web.Response:
    """Handler for the API status endpoint.

    Builds a JSON snapshot containing the product name/version, the
    maintenance flag, the per-endpoint state reported by the shared API's
    endpoint manager, session statistics and the metrics summary.

    Args:
        request: Incoming aiohttp request (not inspected).

    Returns:
        A JSON response describing the current server status.
    """
    api = session_manager.shared_api

    # Snapshot the current state of every endpoint
    endpoint_statuses = [
        {
            'id': ep.id,
            'url': ep.url,
            'busy': ep.busy,
            'last_used': ep.last_used,
            'error_count': ep.error_count,
            'error_until': ep.error_until,
        }
        for ep in api.endpoint_manager.endpoints
    ]

    # Get session statistics
    session_stats = session_manager.get_session_stats()

    # Get metrics
    api_metrics = metrics_tracker.get_metrics()

    # An endpoint counts as active when it is idle and not inside an error
    # back-off window. 'error_until' is always present in the snapshot, so a
    # key-membership test can never succeed; an unset value (None/0) must be
    # treated as "no back-off" instead of being compared against the clock,
    # which would raise TypeError for None.
    now = time.time()
    active_endpoints = sum(
        1
        for status in endpoint_statuses
        if not status['busy']
        and (not status['error_until'] or status['error_until'] < now)
    )

    return web.json_response({
        'product': PRODUCT_NAME,
        'version': PRODUCT_VERSION,
        'maintenance_mode': MAINTENANCE_MODE,
        'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
        'endpoint_status': endpoint_statuses,
        'active_endpoints': active_endpoints,
        'active_sessions': session_stats,
        'metrics': api_metrics
    })
async def metrics_handler(request: web.Request) -> web.Response:
    """Handler for the detailed metrics endpoint (protected).

    The caller must present the API key either as an
    ``Authorization: Bearer <key>`` header or as a ``key`` query parameter.
    The key is checked against ``SECRET_TOKEN``.

    Args:
        request: Incoming aiohttp request carrying the credentials.

    Returns:
        The detailed metrics as JSON, or a 401 JSON error when the key is
        missing or does not match.
    """
    # Local import keeps this handler self-contained.
    import hmac

    # Check for API key in header or query param
    auth_header = request.headers.get('Authorization', '')
    if auth_header.startswith('Bearer '):
        api_key = auth_header[7:]
    else:
        api_key = request.query.get('key')

    # An unset/empty/non-string SECRET_TOKEN can never be matched, so the
    # endpoint stays closed rather than accepting an empty key.
    expected = SECRET_TOKEN if isinstance(SECRET_TOKEN, str) else ''

    # Compare in constant time so response timing does not leak how many
    # leading characters of the key were correct. Bytes are used because
    # compare_digest rejects non-ASCII str; 'surrogatepass' keeps encoding
    # from raising on undecodable header bytes.
    authorized = (
        bool(api_key)
        and bool(expected)
        and hmac.compare_digest(
            api_key.encode('utf-8', 'surrogatepass'),
            expected.encode('utf-8', 'surrogatepass'),
        )
    )
    if not authorized:
        return web.json_response({
            'error': 'Unauthorized'
        }, status=401)

    # Get detailed metrics
    detailed_metrics = metrics_tracker.get_detailed_metrics()
    return web.json_response(detailed_metrics)
# Maps each queue-backed websocket action to a pair:
# (request type used for metrics / rate limiting, name of the session queue attribute).
_WS_ACTION_ROUTES = {
    'join_chat': ('chat', 'chat_queue'),
    'leave_chat': ('chat', 'chat_queue'),
    'chat_message': ('chat', 'chat_queue'),
    'generate_video': ('video', 'video_queue'),
    'search': ('search', 'search_queue'),
    'simulate': ('simulation', 'simulation_queue'),
}


def _resolve_ws_route(action):
    """Return (request_type, queue_attr) for an action; ('other', None) if it is not queue-backed."""
    # Only strings can name a route; anything else (None, numbers, lists...)
    # falls through to the generic handler without risking an unhashable lookup.
    if isinstance(action, str):
        return _WS_ACTION_ROUTES.get(action, ('other', None))
    return ('other', None)


async def websocket_handler(request: web.Request) -> web.WebSocketResponse:
    """Serve one websocket connection for its whole lifetime.

    Validates the ``hf_token`` query parameter to obtain a user role, tracks
    anonymous connections per client IP, registers the connection with the
    metrics tracker, creates a session, and then dispatches every incoming
    JSON text message either to one of the session queues or to the session's
    generic request processor. Non-admin users are rate limited per request
    type. All tracking state is released when the connection ends.

    Args:
        request: The incoming HTTP upgrade request.

    Returns:
        The websocket response once the connection is closed. While
        ``MAINTENANCE_MODE`` is set, a 503 JSON response is returned instead
        and no websocket is opened.
    """
    # Check if maintenance mode is enabled
    if MAINTENANCE_MODE:
        # Return an error response indicating maintenance mode
        return web.json_response({
            'error': 'Server is in maintenance mode',
            'maintenance': True
        }, status=503)  # 503 Service Unavailable

    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*20,  # 20MB max message size
        timeout=30.0  # we want to keep things tight and short
    )
    await ws.prepare(request)

    # Get the Hugging Face token from query parameters
    hf_token = request.query.get('hf_token', '')

    # Generate a unique user ID for this connection
    user_id = str(uuid.uuid4())

    # Validate the token and determine the user role
    user_role = await session_manager.shared_api.validate_user_token(hf_token)
    logger.info(f"User {user_id} connected with role: {user_role}")

    # Get client IP address; the forwarded header is only a fallback for when
    # the transport exposes no peer address
    peername = request.transport.get_extra_info('peername')
    if peername is not None:
        client_ip = peername[0]
    else:
        client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()

    logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")

    # Check for anonymous user connection limits
    if user_role == 'anon':
        async with anon_connection_lock:
            # Track this connection
            anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
            # Store the IP so we can clean up later
            ws.client_ip = client_ip

            # Log multiple connections from same IP but don't restrict them
            if anon_connections[client_ip] > 1:
                logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")

    # Store the user role in the websocket for easy access
    ws.user_role = user_role
    ws.user_id = user_id

    # NOTE(review): if either call below raises, the anonymous connection
    # count incremented above is never decremented, because the cleanup in
    # `finally` only guards the message loop - confirm whether these can fail.
    # Register with metrics
    metrics_tracker.register_session(user_id, client_ip)

    # Create a new session for this user
    user_session = await session_manager.create_session(user_id, user_role, ws)

    try:
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                # Reset per message: the error reply must describe THIS
                # message, never a previous one that happened to parse.
                action = 'unknown'
                try:
                    data = json.loads(msg.data)
                    # Valid JSON that is not an object (e.g. a list or a
                    # number) has no .get(); reject it explicitly so the
                    # failure is reported instead of tearing down the socket.
                    if not isinstance(data, dict):
                        raise ValueError('Message must be a JSON object')
                    action = data.get('action')

                    request_type, queue_attr = _resolve_ws_route(action)

                    # Record the request for metrics
                    await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)

                    # Check rate limits (except for admins)
                    if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
                        await ws.send_json({
                            'action': action,
                            'requestId': data.get('requestId'),
                            'success': False,
                            'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
                        })
                        continue

                    # Route requests to appropriate queues
                    if queue_attr is not None:
                        await getattr(user_session, queue_attr).put(data)
                    else:
                        await user_session.process_generic_request(data)

                except Exception as e:
                    logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
                    await ws.send_json({
                        'action': action,
                        'success': False,
                        'error': f'Error processing message: {str(e)}'
                    })

            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break

    finally:
        # Cleanup session
        await session_manager.delete_session(user_id)

        # Cleanup anonymous connection tracking
        if getattr(ws, 'user_role', None) == 'anon' and hasattr(ws, 'client_ip'):
            client_ip = ws.client_ip
            async with anon_connection_lock:
                if client_ip in anon_connections:
                    anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
                    # Drop empty entries so the dict does not grow without bound
                    if anon_connections[client_ip] == 0:
                        del anon_connections[client_ip]
                    logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")

        # Unregister from metrics
        metrics_tracker.unregister_session(user_id, client_ip)
        logger.info(f"Connection closed for user {user_id}")

    return ws
async def init_app() -> web.Application:
    """Create and configure the aiohttp application.

    Registers the shutdown hook that closes all sessions, the websocket and
    API routes, and a catch-all route that serves static files from the
    ``build/web`` directory next to this file (falling back to the top-level
    ``index.html`` for unknown paths, for SPA routing).

    Returns:
        The configured application, ready to be run.
    """
    app = web.Application(
        client_max_size=1024**2*20  # 20MB max size
    )

    # Add cleanup logic
    async def cleanup(app):
        logger.info("Shutting down server, closing all sessions...")
        await session_manager.close_all_sessions()

    app.on_shutdown.append(cleanup)

    # Add routes
    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)
    app.router.add_get('/api/metrics', metrics_handler)

    # Set up static file serving
    # Define the path to the public directory
    public_path = pathlib.Path(__file__).parent / 'build' / 'web'
    if not public_path.exists():
        public_path.mkdir(parents=True, exist_ok=True)

    # Resolve the root once; every request is checked against this value.
    public_root = public_path.resolve()

    # Content type by lower-cased file extension; anything else is text/plain.
    content_types = {
        '.html': 'text/html',
        '.js': 'application/javascript',
        '.css': 'text/css',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.svg': 'image/svg+xml',
        '.json': 'application/json',
    }

    # Set up static file serving with proper security considerations
    async def static_file_handler(request):
        """Serve a file from the public directory, or index.html as SPA fallback."""
        # Get the path from the request (removing leading /)
        path_parts = request.path.lstrip('/').split('/')

        try:
            safe_path = public_root.joinpath(*path_parts).resolve()
        except (ValueError, FileNotFoundError):
            raise web.HTTPNotFound() from None

        # Make sure the resolved path is the public directory or lies inside
        # it. This compares whole path components: a plain string-prefix test
        # would also accept sibling directories such as "<root>-private".
        if safe_path != public_root and public_root not in safe_path.parents:
            raise web.HTTPForbidden(text="Access denied")

        # If path is a directory, look for index.html
        if safe_path.is_dir():
            safe_path = safe_path / 'index.html'

        # Check if the file exists
        if not safe_path.exists() or not safe_path.is_file():
            # If not found, serve index.html (for SPA routing)
            safe_path = public_root / 'index.html'
            if not safe_path.exists():
                raise web.HTTPNotFound()

        # Determine content type based on file extension
        content_type = content_types.get(safe_path.suffix.lower(), 'text/plain')

        # Return the file with appropriate headers
        return web.FileResponse(safe_path, headers={'Content-Type': content_type})

    # Add catch-all route for static files (lower priority than API routes)
    app.router.add_get('/{path:.*}', static_file_handler)

    return app
if __name__ == '__main__':
    # Pass the coroutine itself: run_app then builds the application on the
    # same event loop that serves it, instead of on a temporary loop that
    # asyncio.run() closes before the server starts.
    web.run_app(init_app(), host='0.0.0.0', port=8080)