File size: 3,692 Bytes
dd07930
 
 
 
 
3dde046
dd07930
3dde046
dd07930
 
 
c46d8cf
dd07930
 
 
c46d8cf
 
 
 
 
 
 
 
 
 
 
 
 
dd07930
c46d8cf
dd07930
13c3439
c46d8cf
13c3439
 
 
 
dd07930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c46d8cf
 
 
 
 
3dde046
 
 
 
 
dd07930
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# server.py
import asyncio
import importlib.metadata
import logging
import os
import re
import secrets
import uuid
from dataclasses import dataclass
from typing import AsyncGenerator, Dict, Tuple, Any, Optional

from quart import Quart, websocket, request, Response
from quart_cors import cors
from quart_schema import QuartSchema, validate_request, validate_response

from broker import SessionBroker, SessionDoesNotExist, ClientRequest, ClientResponse, ClientError

# Configuraci贸n para Hugging Face Spaces
PORT = int(os.getenv('PORT', 7860))
TIMEOUT: int = int(os.getenv('TIMEOUT', 60))
LOG_LEVEL: int = getattr(logging, os.getenv('LOG_LEVEL', 'INFO'))
MAX_MESSAGE_SIZE: int = int(os.getenv('MAX_MESSAGE_SIZE', 16 * 1024 * 1024))
RATE_LIMIT: int = int(os.getenv('RATE_LIMIT', 100))
SESSION_TIMEOUT: int = int(os.getenv('SESSION_TIMEOUT', 3600))

# Configure root logging once, at import time.
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Create the application with CORS enabled.
# quart_cors compares plain-string origins literally, so an entry such as
# "https://*.hf.space" only matches that exact text and never a real Space
# origin. Compiled patterns are matched against the request Origin instead;
# they are anchored so e.g. "https://x.hf.space.evil.com" is rejected.
app = Quart(__name__)
app = cors(app,
    allow_origin=[
        re.compile(r"^https://[A-Za-z0-9.-]+\.hf\.space$"),
        re.compile(r"^https://[A-Za-z0-9.-]+\.huggingface\.co$"),
    ],
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Content-Type"],
    max_age=3600
)
QuartSchema(app)

# Shared broker that routes /command requests to websocket sessions.
broker = SessionBroker()

# Data model definitions
@dataclass()
class Status:
    """Payload returned by ``GET /status``."""

    status: str   # liveness marker; the status route sets it to "OK"
    version: str  # version string reported by the status route

@dataclass()
class Session:
    """First message sent to a websocket client, announcing its session id."""

    session_id: str  # hex token generated for this connection

@dataclass()
class Command:
    """Request body accepted by ``POST /command``."""

    session_id: str  # target websocket session
    command: str     # command text forwarded to that session's client

@dataclass()
class CommandResponse:
    """Successful result body of ``POST /command``, built from the client's reply."""

    return_code: int  # presumably the command's exit status — confirm with client
    stdout: str       # captured standard output reported by the client
    stderr: str       # captured standard error reported by the client

@dataclass()
class ErrorResponse:
    """Error body returned by ``POST /command`` alongside HTTP 500."""

    error: str  # human-readable failure description

# API routes
@app.get("/status")
@validate_response(Status)
async def status() -> Status:
    """Report liveness plus the installed version of this service's package.

    The distribution name defaults to ``'your-package-name'`` and can be
    overridden with the ``PACKAGE_NAME`` environment variable. When the
    distribution is not installed, the version is reported as ``"unknown"``
    instead of failing the whole endpoint with a 500.
    """
    package_name = os.getenv('PACKAGE_NAME', 'your-package-name')
    try:
        version = importlib.metadata.version(package_name)
    except importlib.metadata.PackageNotFoundError:
        # The default name looks like a template placeholder; a missing
        # distribution must not make the liveness endpoint itself error out.
        version = "unknown"
    return Status(status="OK", version=version)

@app.websocket('/session')
async def session_handler():
    """Serve one client session over the ``/session`` websocket.

    Generates a random session id and sends it to the client as a ``Session``
    message. Then it forwards every ``ClientRequest`` that
    ``broker.subscribe(session_id)`` yields. Client replies are consumed
    concurrently by ``_receive``, which runs as a background task for the
    lifetime of this handler.
    """
    session_id = secrets.token_hex()
    app.logger.info(f"New session: {session_id}")
    await websocket.send_as(Session(session_id=session_id), Session)

    # create_task copies the current context, so the websocket stays usable
    # inside _receive (same behaviour as ensure_future on a coroutine).
    receiver = asyncio.create_task(_receive(session_id))
    try:
        # Named client_request so it does not shadow quart's imported `request`.
        async for client_request in broker.subscribe(session_id):
            app.logger.info(f"Sending request {client_request.request_id} to client.")
            await websocket.send_as(client_request, ClientRequest)
    finally:
        receiver.cancel()
        # Reap the task. A failure inside _receive is then logged here instead
        # of surfacing later as "Task exception was never retrieved".
        # return_exceptions=True turns the task's own CancelledError into a
        # result value. Cancellation of *this* handler still propagates.
        (outcome,) = await asyncio.gather(receiver, return_exceptions=True)
        if isinstance(outcome, Exception):
            app.logger.warning(f"Receiver for session {session_id} failed: {outcome!r}")

async def _receive(session_id: str) -> None:
    """Relay every client reply on the current websocket to the broker.

    Loops forever; it only stops when the task is cancelled or when
    ``websocket.receive_as`` / ``broker.receive_response`` raises.
    """
    while True:
        reply = await websocket.receive_as(ClientResponse)
        app.logger.info(f"Received response for session {session_id}: {reply}")
        await broker.receive_response(session_id, reply)

@app.post('/command')
@validate_request(Command)
@validate_response(CommandResponse, 200)
@validate_response(ErrorResponse, 500)
async def command(data: Command) -> Tuple[CommandResponse | ErrorResponse, int]:
    """Run a command on the websocket client bound to ``data.session_id``.

    The request is forwarded through ``broker.send_request``, which waits up
    to ``TIMEOUT`` seconds for the client's answer.

    Returns:
        ``(CommandResponse, 200)`` on success. Otherwise ``(ErrorResponse, 500)``
        when the session is unknown, the client reports an error, the wait
        times out, or the client's reply cannot be turned into a
        ``CommandResponse``.
    """
    try:
        response_data = await broker.send_request(
            session_id=data.session_id,
            data={'action': 'command', 'command': data.command},
            timeout=TIMEOUT
        )
    except SessionDoesNotExist:
        app.logger.warning(f"Invalid session ID: {data.session_id}")
        return ErrorResponse(error='Session does not exist.'), 500
    except ClientError as e:
        app.logger.warning(f"Client error for session {data.session_id}: {e.message}")
        return ErrorResponse(error=e.message), 500
    except asyncio.TimeoutError:
        app.logger.warning(f"Timeout waiting for session {data.session_id}")
        return ErrorResponse(error='Timeout waiting for client response.'), 500

    # Assumes response_data is a mapping with return_code/stdout/stderr keys;
    # missing/extra keys (or a non-mapping) raise TypeError, which previously
    # escaped as an unhandled server error.
    try:
        response = CommandResponse(**response_data)
    except TypeError:
        app.logger.error(
            f"Malformed response for session {data.session_id}: {response_data!r}"
        )
        return ErrorResponse(error='Malformed response from client.'), 500
    return response, 200

# Run the application
def run():
    """Serve the app with Quart's built-in server on all interfaces at ``PORT``."""
    server_options = {
        'host': '0.0.0.0',  # listen on every interface, not just loopback
        'port': PORT,
        'debug': False,
    }
    app.run(**server_options)

# Health-check endpoint
@app.route("/health")
async def health_check():
    """Liveness probe that always answers with a fixed healthy payload."""
    payload = {"status": "healthy"}
    return payload

# Start the server only when executed as a script, not when imported.
if __name__ == '__main__':
    run()