Spaces:
Build error
Build error
from typing import Annotated, Any, Generator | |
from pathlib import Path | |
from gymnasium.wrappers.record_video import RecordVideo | |
from litrl.env.make import make | |
from litrl.common.agent import RandomAgent | |
from litrl.env.typing import SingleAgentId | |
from fastapi import Depends, FastAPI, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from litrl.env.connect_four import Board | |
from loguru import logger | |
from fastapi.responses import StreamingResponse | |
from src.app_state import AppState | |
from src.typing import CpuConfig | |
from src.huggingface.huggingface_client import HuggingFaceClient | |
def stream_mp4(mp4_path: Path, chunk_size: int = 64 * 1024) -> StreamingResponse:
    """Stream an mp4 file from disk as an HTTP response.

    Args:
        mp4_path: Path of the mp4 file to send.
        chunk_size: Number of bytes read and yielded per chunk; must be positive.

    Returns:
        A ``StreamingResponse`` with media type ``video/mp4`` whose body is read
        lazily from ``mp4_path`` (the file is only opened once streaming starts).

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # read(0) would end the stream immediately and read(-1) would load the
        # whole file at once; neither is a meaningful chunk size.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def iter_file() -> Generator[bytes, Any, None]:
        # Read fixed-size chunks: iterating a binary file splits on b"\n", which
        # for mp4 data yields arbitrarily large or tiny pieces.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")
def _record_random_episode(env_id: SingleAgentId) -> Path:
    """Play one episode (reset with seed 123) with a random agent and record it.

    Args:
        env_id: Identifier passed to ``make`` to build the environment.

    Returns:
        Path of the mp4 written by the ``RecordVideo`` wrapper into ``tmp``.
    """
    env = RecordVideo(
        env=make(id=env_id, render_mode="rgb_array"),
        video_folder="tmp",
    )
    try:
        env.reset(seed=123)
        agent = RandomAgent[Any, Any]()
        terminated, truncated = False, False
        while not (terminated or truncated):
            action = agent.get_action(env=env)
            _, _, terminated, truncated, _ = env.step(action=action)
            # NOTE(review): RecordVideo appears to capture frames itself on each
            # step, so this extra render may be redundant -- confirm before removing.
            env.render()
        # Finish writing the video before its path is handed to the streamer.
        env.video_recorder.close()
        return Path(env.video_recorder.path)
    finally:
        # Release the environment even if reset/step raises; previously it was
        # never closed.
        env.close()


def create_app() -> FastAPI:
    """Build the FastAPI application.

    Defines the request handlers, registers the validation-error handler and
    installs a permissive CORS middleware.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    # NOTE(review): no route decorators are visible for the handlers below, so
    # as shown they are never attached to ``app``. Confirm whether decorators
    # such as ``@app.post(...)`` were lost. The three ``bot_action``
    # definitions rebind the same local name, as in the original.

    def bot_action(
        board: Board,
        cpuConfig: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> int:
        """Configure ``app_state`` with ``cpuConfig`` and ``board``; return its action."""
        app_state.set_config(cpu_config=cpuConfig)
        app_state.set_board(board=board)
        return app_state.get_action()

    def bot_action(
        env_id: SingleAgentId,
    ) -> StreamingResponse:
        """Record a random-agent episode of ``env_id`` and stream it as mp4."""
        return stream_mp4(mp4_path=_record_random_episode(env_id=env_id))

    def fh_stream(
        env_id: SingleAgentId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the mp4 that ``hf_client.mp4_paths`` holds for ``env_id``.

        Indexing ``mp4_paths`` presumably raises (e.g. ``KeyError`` for a dict)
        when ``env_id`` is unknown -- confirm against ``HuggingFaceClient``.
        """
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    def bot_action(
        env_id: SingleAgentId,
    ) -> StreamingResponse:
        """Record a random-agent episode of ``env_id`` and stream it as mp4."""
        return stream_mp4(mp4_path=_record_random_episode(env_id=env_id))

    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        """Log a request-validation failure and answer with a 422 JSON body."""
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            # Presumably only set once the request body has been read -- hence
            # the guard.
            logger.debug(f"body: {request._body}")
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Collapse newlines and indentation runs into single spaces. The
        # previous ``.replace(" ", " ")`` was a no-op.
        exc_str = " ".join(f"{exc}".split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(
            content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )

    # Without this registration the handler above is never invoked.
    app.add_exception_handler(RequestValidationError, validation_exception_handler)

    app.add_middleware(
        middleware_class=CORSMiddleware,
        # Starlette documents ``allow_origins`` as a list of origins; the bare
        # string "*" only worked by accident.
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app
if __name__ == "__main__":
    # Command-line entry point: serve create_app() with uvicorn.
    import argparse

    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()

    server_config = uvicorn.Config(
        app=create_app(),
        host=opts.host,
        port=opts.port,
        log_level="info",
    )
    uvicorn.Server(config=server_config).run()