# Source: aitube2 / api.py
# Last commit d7edecf by jbilcke-hf (HF Staff): "working on the chat system"
import asyncio
import hmac
import json
import logging
import os
import pathlib
import time
import uuid
from typing import Dict, Any

from aiohttp import web, WSMsgType

from api_core import VideoGenerationAPI
from api_session import SessionManager
from api_metrics import MetricsTracker
from api_config import *
# Configure logging
# Root logger emits INFO and above as "timestamp - logger name - level - message".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Create global session and metrics managers
# Module-level singletons shared by every request handler in this file.
session_manager = SessionManager()
metrics_tracker = MetricsTracker()
# Dictionary to track connected anonymous clients by IP address
# Maps client IP -> number of currently open anonymous WebSocket connections.
# websocket_handler increments/decrements it while holding anon_connection_lock.
anon_connections: Dict[str, int] = {}
# NOTE(review): this lock is created at import time, before any event loop is
# running; on Python < 3.10 asyncio primitives bind to the loop that is current
# at construction - confirm the deployment targets Python 3.10+.
anon_connection_lock = asyncio.Lock()
async def status_handler(request: web.Request) -> web.Response:
    """Handler for API status endpoint.

    Returns a JSON document with the product name/version, the maintenance
    flag, the per-endpoint state read from the shared API's endpoint manager,
    session statistics and the summary metrics.

    NOTE(review): this handler performs no authentication and includes every
    endpoint's ``url`` in the response - confirm that exposure is intended.

    Args:
        request: The incoming aiohttp request (not inspected).

    Returns:
        A JSON ``web.Response`` describing the current API status.
    """
    api = session_manager.shared_api

    # Sample the clock once so all endpoints are judged against the same instant.
    now = time.time()

    # Get current busy status of all endpoints
    endpoint_statuses = [
        {
            'id': ep.id,
            'url': ep.url,
            'busy': ep.busy,
            'last_used': ep.last_used,
            'error_count': ep.error_count,
            'error_until': ep.error_until,
        }
        for ep in api.endpoint_manager.endpoints
    ]

    # An endpoint counts as active when it is idle and not inside an error
    # back-off window. A falsy ``error_until`` (None or 0) means "no back-off";
    # testing it first avoids a TypeError from comparing None with a float.
    active_endpoints = sum(
        1
        for status in endpoint_statuses
        if not status['busy']
        and (not status['error_until'] or status['error_until'] < now)
    )

    # Get session statistics
    session_stats = session_manager.get_session_stats()

    # Get metrics
    api_metrics = metrics_tracker.get_metrics()

    return web.json_response({
        'product': PRODUCT_NAME,
        'version': PRODUCT_VERSION,
        'maintenance_mode': MAINTENANCE_MODE,
        'available_endpoints': len(VIDEO_ROUND_ROBIN_ENDPOINT_URLS),
        'endpoint_status': endpoint_statuses,
        'active_endpoints': active_endpoints,
        'active_sessions': session_stats,
        'metrics': api_metrics
    })
async def metrics_handler(request: web.Request) -> web.Response:
    """Handler for detailed metrics endpoint (protected).

    The caller must present the API key either as an ``Authorization: Bearer
    <key>`` header or, when no Bearer header is sent, as the ``key`` query
    parameter. The key is checked against ``SECRET_TOKEN``.

    Args:
        request: The incoming aiohttp request carrying the credentials.

    Returns:
        The detailed metrics as JSON, or a 401 JSON error when the key is
        missing or does not match.
    """
    # Check for API key in header or query param
    auth_header = request.headers.get('Authorization', '')
    bearer_prefix = 'Bearer '
    if auth_header.startswith(bearer_prefix):
        api_key = auth_header[len(bearer_prefix):]
    else:
        api_key = request.query.get('key')

    # Validate API key (using SECRET_TOKEN as the API key).
    # An empty/unset key or secret never authorizes. The comparison itself is
    # constant-time so response timing does not leak how much of the key matched.
    # Both sides are encoded to bytes because compare_digest rejects non-ASCII
    # str; 'surrogatepass' keeps odd header bytes from raising during encoding.
    expected = SECRET_TOKEN
    authorized = (
        bool(api_key)
        and isinstance(expected, str)
        and bool(expected)
        and hmac.compare_digest(
            api_key.encode('utf-8', 'surrogatepass'),
            expected.encode('utf-8', 'surrogatepass'),
        )
    )
    if not authorized:
        return web.json_response({
            'error': 'Unauthorized'
        }, status=401)

    # Get detailed metrics
    detailed_metrics = metrics_tracker.get_detailed_metrics()
    return web.json_response(detailed_metrics)
# Routing table for WebSocket actions:
# action -> (rate-limit/metrics category, name of the session queue attribute).
# Any action not listed here is counted as 'other' and handed to
# user_session.process_generic_request().
_WS_ACTION_ROUTES = {
    'join_chat': ('chat', 'chat_queue'),
    'leave_chat': ('chat', 'chat_queue'),
    'chat_message': ('chat', 'chat_queue'),
    'generate_video': ('video', 'video_queue'),
    'search': ('search', 'search_queue'),
    'simulate': ('simulation', 'simulation_queue'),
}


async def websocket_handler(request: web.Request) -> web.StreamResponse:
    """Serve one WebSocket connection for its whole lifetime.

    Steps:
      1. Refuse with a 503 JSON response while ``MAINTENANCE_MODE`` is set.
      2. Upgrade the request, validate the ``hf_token`` query parameter to get
         the user's role, and work out the client IP.
      3. Count anonymous connections per IP (logged, not restricted), register
         the connection with the metrics tracker and create a session.
      4. For each text frame: parse JSON, record the request, enforce rate
         limits (admins are exempt) and route it to the matching session queue
         or to the generic request processor.
      5. On exit, undo exactly the bookkeeping that was performed.

    Args:
        request: The incoming aiohttp request to upgrade.

    Returns:
        The ``WebSocketResponse`` once the connection has ended, or a plain
        503 JSON response in maintenance mode.
    """
    # Check if maintenance mode is enabled
    if MAINTENANCE_MODE:
        # Return an error response indicating maintenance mode
        return web.json_response({
            'error': 'Server is in maintenance mode',
            'maintenance': True
        }, status=503)  # 503 Service Unavailable

    ws = web.WebSocketResponse(
        max_msg_size=1024*1024*20,  # 20MB max message size
        timeout=30.0  # we want to keep things tight and short
    )
    await ws.prepare(request)

    # Get the Hugging Face token from query parameters
    hf_token = request.query.get('hf_token', '')

    # Generate a unique user ID for this connection
    user_id = str(uuid.uuid4())

    # Validate the token and determine the user role
    user_role = await session_manager.shared_api.validate_user_token(hf_token)
    logger.info(f"User {user_id} connected with role: {user_role}")

    # Get client IP address. The transport can already be gone if the peer
    # disconnected, so guard against None before asking it for the peer name.
    transport = request.transport
    peername = transport.get_extra_info('peername') if transport is not None else None
    if peername is not None:
        client_ip = peername[0]
    else:
        client_ip = request.headers.get('X-Forwarded-For', 'unknown').split(',')[0].strip()

    logger.info(f"Client {user_id} connecting from IP: {client_ip} with role: {user_role}")

    # Flags recording which bookkeeping steps succeeded, so the cleanup below
    # only undoes what was actually done (even if setup fails half-way).
    anon_tracked = False
    metrics_registered = False
    session_created = False

    try:
        # Check for anonymous user connection limits
        if user_role == 'anon':
            async with anon_connection_lock:
                # Track this connection
                anon_connections[client_ip] = anon_connections.get(client_ip, 0) + 1
                anon_tracked = True
                # Store the IP so we can clean up later
                ws.client_ip = client_ip

                # Log multiple connections from same IP but don't restrict them
                if anon_connections[client_ip] > 1:
                    logger.info(f"Multiple anonymous connections from IP {client_ip}: {anon_connections[client_ip]} connections")

        # Store the user role in the websocket for easy access
        ws.user_role = user_role
        ws.user_id = user_id

        # Register with metrics
        metrics_tracker.register_session(user_id, client_ip)
        metrics_registered = True

        # Create a new session for this user
        user_session = await session_manager.create_session(user_id, user_role, ws)
        session_created = True

        async for msg in ws:
            if msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
            if msg.type != WSMsgType.TEXT:
                continue

            # Reset per message so an error reply can never report the action
            # of a previously processed message.
            data = None
            try:
                data = json.loads(msg.data)
                if not isinstance(data, dict):
                    # Valid JSON but not an object (e.g. a list or a number):
                    # there is no 'action' to look up, so reject it cleanly.
                    raise ValueError('message must be a JSON object')

                action = data.get('action')

                # Only string actions can be routed; the isinstance check also
                # keeps unhashable values (lists, dicts) away from the dict lookup.
                route = _WS_ACTION_ROUTES.get(action) if isinstance(action, str) else None
                request_type = route[0] if route is not None else 'other'

                # Record the request for metrics
                await metrics_tracker.record_request(user_id, client_ip, request_type, user_role)

                # Check rate limits (except for admins)
                if user_role != 'admin' and await metrics_tracker.is_rate_limited(user_id, request_type, user_role):
                    await ws.send_json({
                        'action': action,
                        'requestId': data.get('requestId'),
                        'success': False,
                        'error': f'Rate limit exceeded for {request_type} requests. Please try again later.'
                    })
                    continue

                # Route requests to appropriate queues
                if route is not None:
                    await getattr(user_session, route[1]).put(data)
                else:
                    await user_session.process_generic_request(data)

            except Exception as e:
                # Per-message boundary: report the failure to the client and
                # keep the connection open for subsequent messages.
                logger.error(f"Error processing WebSocket message for user {user_id}: {str(e)}")
                await ws.send_json({
                    'action': data.get('action') if isinstance(data, dict) else 'unknown',
                    'success': False,
                    'error': f'Error processing message: {str(e)}'
                })

    finally:
        try:
            # Cleanup session
            if session_created:
                await session_manager.delete_session(user_id)
        finally:
            # Runs even if deleting the session fails, so the per-IP counter
            # and the metrics registration never leak.
            # Cleanup anonymous connection tracking
            if anon_tracked:
                async with anon_connection_lock:
                    if client_ip in anon_connections:
                        anon_connections[client_ip] = max(0, anon_connections[client_ip] - 1)
                        if anon_connections[client_ip] == 0:
                            del anon_connections[client_ip]
                        logger.info(f"Anonymous connection from {client_ip} closed. Remaining: {anon_connections.get(client_ip, 0)}")

            # Unregister from metrics
            if metrics_registered:
                metrics_tracker.unregister_session(user_id, client_ip)
            logger.info(f"Connection closed for user {user_id}")

    return ws
async def init_app() -> web.Application:
    """Build the aiohttp application.

    Registers the shutdown hook that closes all sessions, the WebSocket and
    API routes, and a catch-all GET route that serves the static web build
    from ``build/web`` next to this file (created if missing), falling back
    to ``index.html`` for unknown paths so client-side (SPA) routing works.

    Returns:
        The configured ``web.Application``.
    """
    app = web.Application(
        client_max_size=1024**2*20  # 20MB max size
    )

    # Add cleanup logic
    async def cleanup(app):
        logger.info("Shutting down server, closing all sessions...")
        await session_manager.close_all_sessions()

    app.on_shutdown.append(cleanup)

    # Add routes
    app.router.add_get('/ws', websocket_handler)
    app.router.add_get('/api/status', status_handler)
    app.router.add_get('/api/metrics', metrics_handler)

    # Set up static file serving
    # Define the path to the public directory
    public_path = pathlib.Path(__file__).parent / 'build' / 'web'
    if not public_path.exists():
        public_path.mkdir(parents=True, exist_ok=True)

    # Explicit Content-Type per file extension; anything else is text/plain.
    content_types = {
        '.html': 'text/html',
        '.js': 'application/javascript',
        '.css': 'text/css',
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.gif': 'image/gif',
        '.svg': 'image/svg+xml',
        '.json': 'application/json',
        # Common web-build assets that must not be served as text/plain
        # (browsers refuse streaming compilation of mistyped .wasm files).
        '.wasm': 'application/wasm',
        '.ico': 'image/x-icon',
        '.webp': 'image/webp',
        '.woff': 'font/woff',
        '.woff2': 'font/woff2',
        '.ttf': 'font/ttf',
        '.otf': 'font/otf',
    }

    # Set up static file serving with proper security considerations
    async def static_file_handler(request):
        """Serve a file from the public directory, or index.html as SPA fallback."""
        # Get the path from the request (removing leading /)
        path_parts = request.path.lstrip('/').split('/')

        # Resolve symlinks and '..' segments before any containment check.
        try:
            safe_path = public_path.joinpath(*path_parts).resolve()
            public_root = public_path.resolve()
        except (ValueError, OSError):
            # e.g. embedded NUL byte or a path the OS refuses to resolve
            raise web.HTTPNotFound()

        # Make sure the path is within the public directory (prevent directory
        # traversal). This compares whole path components: a plain string
        # prefix test would also accept a sibling such as "<root>-secret/...".
        try:
            safe_path.relative_to(public_root)
        except ValueError:
            raise web.HTTPForbidden(text="Access denied")

        # If path is a directory, look for index.html
        if safe_path.is_dir():
            safe_path = safe_path / 'index.html'

        # Check if the file exists (is_file() is False for missing paths too)
        if not safe_path.is_file():
            # If not found, serve index.html (for SPA routing)
            safe_path = public_path / 'index.html'
            if not safe_path.exists():
                raise web.HTTPNotFound()

        # Determine content type based on file extension
        content_type = content_types.get(safe_path.suffix.lower(), 'text/plain')

        # Return the file with appropriate headers
        return web.FileResponse(safe_path, headers={'Content-Type': content_type})

    # Add catch-all route for static files (lower priority than API routes)
    app.router.add_get('/{path:.*}', static_file_handler)

    return app
if __name__ == '__main__':
    # Pass the coroutine itself to run_app so the application is built on the
    # same event loop that serves it. Building it with asyncio.run() first
    # would create and then close a separate loop before the server starts.
    web.run_app(init_app(), host='0.0.0.0', port=8080)