"""FastAPI service: bot moves for Connect Four boards and recorded random-agent episodes."""

import shutil
import tempfile
from pathlib import Path
from typing import Annotated, Any, Generator

from fastapi import BackgroundTasks, Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from gymnasium.wrappers.record_video import RecordVideo
from litrl.common.agent import RandomAgent
from litrl.env.connect_four import Board
from litrl.env.make import make
from litrl.env.typing import SingleAgentId
from loguru import logger
from moviepy.editor import VideoFileClip

from src.app_state import AppState
from src.typing import CpuConfig


def _record_random_episode(env_id: SingleAgentId, video_folder: Path) -> Path:
    """Play one episode (reset with seed 123) of ``env_id`` with a random agent, recording it.

    Args:
        env_id: Environment identifier forwarded to ``litrl.env.make.make``.
        video_folder: Directory handed to the ``RecordVideo`` wrapper. Each caller
            should pass its own directory so concurrent recordings cannot overwrite
            each other's files.

    Returns:
        Path of the video file written by the wrapper's recorder.
    """
    env = make(id=env_id, render_mode="rgb_array")
    env = RecordVideo(env=env, video_folder=str(video_folder))
    try:
        env.reset(seed=123)
        agent = RandomAgent[Any, Any]()
        terminated, truncated = False, False
        while not (terminated or truncated):
            action = agent.get_action(env=env)
            _, _, terminated, truncated, _ = env.step(action=action)
            # NOTE(review): RecordVideo captures frames itself on step(); this
            # extra render may be redundant — confirm before removing.
            env.render()
        # Flush/finalize the video file before handing its path to the caller.
        env.video_recorder.close()
        return Path(env.video_recorder.path)
    finally:
        # The original never closed the environment; release it on every path.
        env.close()


def create_app() -> FastAPI:
    """Build the FastAPI application with its routes, 422 handler and CORS middleware."""
    app = FastAPI()

    @app.post("/", response_model=int)
    def bot_action(
        board: Board,
        cpuConfig: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> int:
        """Return the action ``AppState`` picks for ``board`` under ``cpuConfig``."""
        app_state.set_config(cpu_config=cpuConfig)
        app_state.set_board(board=board)
        return app_state.get_action()

    # ``name="bot_action"`` keeps the route name (and hence the generated
    # OpenAPI operationId) identical to when these handlers were all called
    # ``bot_action``, so generated clients keep working.
    @app.post(path="/game", response_class=StreamingResponse, name="bot_action")
    def game_video(
        env_id: SingleAgentId,
        background_tasks: BackgroundTasks,
    ) -> StreamingResponse:
        """Record a random-agent episode of ``env_id`` and stream it back as MP4."""
        # Per-request directory: sync endpoints run concurrently in a threadpool,
        # and a shared folder would let requests overwrite each other's video.
        workspace = Path(tempfile.mkdtemp(prefix="litrl-game-"))
        try:
            mp4_path = _record_random_episode(env_id=env_id, video_folder=workspace / "video")
        except Exception:
            shutil.rmtree(workspace, ignore_errors=True)
            raise
        # FastAPI attaches these tasks to the returned response; they run after
        # the body has been sent.
        background_tasks.add_task(shutil.rmtree, workspace, ignore_errors=True)

        def iterfile() -> Generator[bytes, Any, None]:
            with mp4_path.open(mode="rb") as video_file:
                yield from video_file

        return StreamingResponse(content=iterfile(), media_type="video/mp4")

    @app.get(path="/gif", response_class=FileResponse, name="bot_action")
    def game_gif(
        env_id: SingleAgentId,
        background_tasks: BackgroundTasks,
    ) -> FileResponse:
        """Record a random-agent episode of ``env_id`` and return it as a GIF."""
        workspace = Path(tempfile.mkdtemp(prefix="litrl-gif-"))
        try:
            mp4_path = _record_random_episode(env_id=env_id, video_folder=workspace / "video")
            gif_path = mp4_path.with_suffix(".gif")
            # Context manager closes the clip's underlying reader.
            with VideoFileClip(str(mp4_path)) as video_clip:
                video_clip.write_gif(str(gif_path))  # TODO: check fps (e.g. fps=30)
        except Exception:
            shutil.rmtree(workspace, ignore_errors=True)
            raise
        # Delete the workspace only after FileResponse has sent the file.
        background_tasks.add_task(shutil.rmtree, workspace, ignore_errors=True)
        return FileResponse(str(gif_path), media_type="image/gif")

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        """Log the rejected request and answer 422 with a single-line error message."""
        logger.debug(f"url: {request.url}")
        # ``_body`` is a private Starlette attribute that only exists once the
        # body has been read.
        if hasattr(request, "_body"):
            logger.debug(f"body: {request._body}")
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Flatten newlines and runs of whitespace into single spaces.
        exc_str = " ".join(f"{exc}".split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(
            content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )

    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app


if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    config = uvicorn.Config(
        app=create_app(), host=args.host, port=args.port, log_level="info"
    )
    server = uvicorn.Server(config=config)
    server.run()