Spaces:
Runtime error
Runtime error
File size: 6,546 Bytes
dd07930 76ca3e9 dd07930 76ca3e9 53b7d1b 76ca3e9 1d7545f 76ca3e9 49a61bb 76ca3e9 dd07930 76ca3e9 49a61bb 76ca3e9 49a61bb 76ca3e9 dd07930 49a61bb dd07930 53b7d1b 26b4d36 362d092 26b4d36 76ca3e9 dd07930 76ca3e9 dd07930 76ca3e9 dd07930 53b7d1b 1d7545f 53b7d1b 1d7545f dd07930 49a61bb dd07930 362d092 dd07930 362d092 dd07930 76ca3e9 dd07930 76ca3e9 dd07930 76ca3e9 dd07930 76ca3e9 dd07930 49a61bb dd07930 49a61bb |
|
# server.py
import asyncio
import json
import logging
import re
import secrets
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Tuple

import yaml
from quart import Quart, websocket, request, send_from_directory, redirect
from quart_cors import cors
from quart_schema import QuartSchema, validate_request, validate_response
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware

from broker import SessionBroker, SessionDoesNotExist, ClientRequest, ClientResponse, ClientError
# Configuration
VERSION = "1.0.0"  # API version, reported by GET /status
TIMEOUT: int = 40  # timeout passed to broker.send_request when waiting for a client reply
LOG_LEVEL: int = logging.DEBUG
# Peers whose X-Forwarded-* headers ProxyHeadersMiddleware will trust.
TRUSTED_HOSTS: list[str] = ["127.0.0.1", "10.16.38.136", "10.16.3.13", "10.16.13.73"]

# Create app
app = Quart(__name__)
# quart_cors does not expand shell-style wildcards in plain-string origins; a
# string like "https://*.hf.space" only matches that literal text. Compiled
# patterns are matched against the request Origin instead. The trailing \Z
# keeps e.g. "https://x.hf.space.evil.example" from matching as a prefix.
app = cors(
    app,
    allow_origin=[
        re.compile(r"https://[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*\.hf\.space\Z"),
        re.compile(r"https://[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*\.huggingface\.co\Z"),
    ],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type"],
    max_age=3600,
)

# Load the hand-written OpenAPI specification that ships next to this file.
OPENAPI_PATH = Path(__file__).parent / "openapi.yaml"
# Read as UTF-8 explicitly: the platform default encoding may differ.
with open(OPENAPI_PATH, encoding="utf-8") as f:
    openapi_spec = yaml.safe_load(f)

# Configure Quart-Schema with the OpenAPI specification.
schema = QuartSchema(app)
# NOTE(review): update_openapi() does not appear in Quart-Schema's documented
# API. Calling it unconditionally aborts start-up with AttributeError when it
# is missing, so apply the spec only if the installed version provides it.
_update_openapi = getattr(schema, "update_openapi", None)
if callable(_update_openapi):
    _update_openapi(openapi_spec)
else:
    app.logger.warning(
        "QuartSchema has no update_openapi(); openapi.yaml was loaded but not applied."
    )

# Take the client address/scheme from X-Forwarded-* headers sent by trusted proxies.
app.asgi_app = ProxyHeadersMiddleware(app.asgi_app, trusted_hosts=TRUSTED_HOSTS)
app.logger.setLevel(LOG_LEVEL)
# Links HTTP API calls (send_request) to the websocket session of the same id (subscribe).
broker = SessionBroker()
# Data models
# Dataclasses used with quart_schema (validate_request / validate_response and
# websocket send_as). Their field names become the JSON keys on the wire, so
# renaming a field changes the public API.
@dataclass
class Status:
    # Body of GET /status.
    status: str  # status() always sends "OK"
    version: str  # API version string (VERSION)


@dataclass
class Session:
    # First websocket message on /session: tells the client its session id.
    session_id: str  # random hex token from secrets.token_hex()


@dataclass
class Command:
    # Body of POST /command.
    session_id: str  # id of the websocket session that should run the command
    command: str  # command text forwarded to the client unchanged


@dataclass
class Read:
    # Body of POST /read.
    session_id: str
    path: str  # forwarded to the client; presumably a path on the client's filesystem — confirm


@dataclass
class Write:
    # Body of POST /write.
    session_id: str
    path: str  # destination path forwarded to the client (same meaning as Read.path)
    content: str  # text the client is asked to write


@dataclass
class CommandResponse:
    # 200 body of POST /command, built by unpacking the client's reply.
    # NOTE(review): presumably the exit status and captured output of the
    # command run by the client — confirm against the client implementation.
    return_code: int
    stdout: str
    stderr: str


@dataclass
class ReadResponse:
    # 200 body of POST /read, built by unpacking the client's reply.
    content: str  # presumably the file contents as text — confirm


@dataclass
class WriteResponse:
    # 200 body of POST /write, built by unpacking the client's reply.
    size: int  # NOTE(review): presumably the amount written (bytes or chars) — confirm


@dataclass
class ErrorResponse:
    # Body returned with status 500 by /command, /read and /write on failure.
    error: str  # human-readable error message
# API routes
@app.route('/')
async def root():
    """Serve the bundled Swagger UI page, ``static/swagger.html``.

    The ``static`` directory is resolved next to this file (as ``openapi.yaml``
    is), not against the process's current working directory, so the page is
    found no matter where the server is started from.
    """
    static_dir = Path(__file__).parent / 'static'
    return await send_from_directory(static_dir, 'swagger.html')
@app.route('/docs')
async def docs():
    """Send ``/docs`` visitors to the Swagger UI page served at ``/``."""
    # NOTE(review): QuartSchema appears to register its own Swagger UI at /docs
    # by default (swagger_ui_path); depending on route precedence this handler
    # may never be reached — confirm.
    target = '/'
    return redirect(target)
@app.get("/status")
@validate_response(Status)
async def status() -> Status:
    """Report that the service is up, together with the API version."""
    payload = Status(version=VERSION, status="OK")
    return payload
@app.websocket('/session')
async def session_handler():
    """Serve one client session over the ``/session`` websocket.

    A random session id is generated and sent to the client as a ``Session``
    message. Every ``ClientRequest`` that ``broker.subscribe`` yields for that
    id is forwarded to the client. Meanwhile a background ``_receive`` task
    hands the client's ``ClientResponse`` messages back to the broker.

    Raises:
        asyncio.CancelledError: re-raised after logging. Quart cancels a
            websocket handler when its client disconnects.
        Exception: any other error is logged once and re-raised.
    """
    session_id = secrets.token_hex()
    receive_task = None
    try:
        app.logger.info(f"{websocket.remote_addr} - NEW SESSION - {session_id}")
        await websocket.send_as(Session(session_id=session_id), Session)
        # asyncio tasks run in a copy of the current context, which is what
        # lets _receive keep using Quart's context-local ``websocket``.
        receive_task = asyncio.ensure_future(_receive(session_id))
        # Named client_request so the module-level quart ``request`` is not shadowed.
        async for client_request in broker.subscribe(session_id):
            app.logger.info(
                f"{websocket.remote_addr} - REQUEST - {session_id} - "
                f"{json.dumps(asdict(client_request))}"
            )
            await websocket.send_as(client_request, ClientRequest)
    except asyncio.CancelledError:
        # Disconnection surfaces as cancellation of this handler. Catching an
        # exception class from the un-imported ``websockets`` package would
        # itself raise NameError here.
        app.logger.warning(f"{websocket.remote_addr} - CONNECTION CLOSED - {session_id}")
        raise
    except Exception as e:
        app.logger.error(f"{websocket.remote_addr} - ERROR - {session_id} - {e}")
        raise
    finally:
        if receive_task is not None:
            receive_task.cancel()
            # gather(..., return_exceptions=True) waits for the task to wind
            # down without re-raising its CancelledError (or an error that
            # _receive has already logged). That way it cannot replace the
            # exception that is ending this handler.
            await asyncio.gather(receive_task, return_exceptions=True)
async def _receive(session_id: str) -> None:
    """Relay the client's websocket messages to the broker until cancelled.

    Runs as the background task started by ``session_handler``. Each incoming
    message is read as a ``ClientResponse`` and passed to
    ``broker.receive_response`` for ``session_id``.

    Raises:
        asyncio.CancelledError: re-raised after logging. session_handler
            cancels this task when the session ends.
        Exception: any other error is logged and re-raised.
    """
    try:
        while True:
            response = await websocket.receive_as(ClientResponse)
            app.logger.info(
                f"{websocket.remote_addr} - RESPONSE - {session_id} - "
                f"{json.dumps(asdict(response))}"
            )
            await broker.receive_response(session_id, response)
    except asyncio.CancelledError:
        # Replaces a handler for ``websockets.exceptions.ConnectionClosed``,
        # whose package is never imported. Evaluating that name on every
        # cancellation raised NameError instead of ending the task cleanly.
        app.logger.warning(f"{websocket.remote_addr} - RECEIVE CONNECTION CLOSED - {session_id}")
        raise
    except Exception as e:
        app.logger.error(f"{websocket.remote_addr} - RECEIVE ERROR - {session_id} - {e}")
        raise
@app.post('/command')
@validate_request(Command)
@validate_response(CommandResponse, 200)
@validate_response(ErrorResponse, 500)
async def command(data: Command) -> Tuple[CommandResponse | ErrorResponse, int]:
    """Ask the client of ``data.session_id`` to run ``data.command``.

    The request is relayed through the broker with ``timeout=TIMEOUT``. The
    client's reply is unpacked into a ``CommandResponse``.

    Returns:
        ``(CommandResponse, 200)`` on success. Returns ``(ErrorResponse, 500)``
        when the session is unknown, the client reports an error, or waiting
        for the client times out.
    """
    message = {'action': 'command', 'command': data.command}
    try:
        reply = await broker.send_request(data.session_id, message, timeout=TIMEOUT)
    except SessionDoesNotExist:
        app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {data.session_id!r}")
        return ErrorResponse('Session does not exist.'), 500
    except ClientError as err:
        return ErrorResponse(err.message), 500
    except asyncio.TimeoutError:
        return ErrorResponse('Timeout when waiting for client.'), 500
    return CommandResponse(**reply), 200
@app.post('/read')
@validate_request(Read)
@validate_response(ReadResponse, 200)
@validate_response(ErrorResponse, 500)
async def read(data: Read) -> Tuple[ReadResponse | ErrorResponse, int]:
    """Ask the client of ``data.session_id`` to read ``data.path``.

    The request is relayed through the broker with ``timeout=TIMEOUT``. The
    client's reply is unpacked into a ``ReadResponse``.

    Returns:
        ``(ReadResponse, 200)`` on success. Returns ``(ErrorResponse, 500)``
        when the session is unknown, the client reports an error, or waiting
        for the client times out.
    """
    message = {'action': 'read', 'path': data.path}
    try:
        reply = await broker.send_request(data.session_id, message, timeout=TIMEOUT)
    except SessionDoesNotExist:
        app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {data.session_id!r}")
        return ErrorResponse('Session does not exist.'), 500
    except ClientError as err:
        return ErrorResponse(err.message), 500
    except asyncio.TimeoutError:
        return ErrorResponse('Timeout when waiting for client.'), 500
    return ReadResponse(**reply), 200
@app.post('/write')
@validate_request(Write)
@validate_response(WriteResponse, 200)
@validate_response(ErrorResponse, 500)
async def write(data: Write) -> Tuple[WriteResponse | ErrorResponse, int]:
    """Ask the client of ``data.session_id`` to write ``data.content`` to ``data.path``.

    The request is relayed through the broker with ``timeout=TIMEOUT``. The
    client's reply is unpacked into a ``WriteResponse``.

    Returns:
        ``(WriteResponse, 200)`` on success. Returns ``(ErrorResponse, 500)``
        when the session is unknown, the client reports an error, or waiting
        for the client times out.
    """
    message = {'action': 'write', 'path': data.path, 'content': data.content}
    try:
        reply = await broker.send_request(data.session_id, message, timeout=TIMEOUT)
    except SessionDoesNotExist:
        app.logger.warning(f"{request.remote_addr} - INVALID SESSION ID - {data.session_id!r}")
        return ErrorResponse('Session does not exist.'), 500
    except ClientError as err:
        return ErrorResponse(err.message), 500
    except asyncio.TimeoutError:
        return ErrorResponse('Timeout when waiting for client.'), 500
    return WriteResponse(**reply), 200
@app.route("/health")
async def health_check():
    """Liveness probe: always answers ``{"status": "healthy"}``."""
    body = dict(status="healthy")
    return body
def run(host: str = '0.0.0.0', port: int = 7860) -> None:
    """Start Quart's built-in server.

    Args:
        host: interface to bind. The default listens on all interfaces.
        port: TCP port to listen on. Defaults to 7860.
    """
    app.run(host=host, port=port)


if __name__ == "__main__":
    run()