"""FastAPI application: Connect Four play endpoints and environment video streaming."""

import sys
from pathlib import Path
from typing import Any, Generator

if sys.version_info[:2] >= (3, 11):
    from typing import Annotated
else:
    from typing_extensions import Annotated

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from gymnasium.wrappers.record_video import RecordVideo
from loguru import logger

from litrl.algo.mcts.agent import MCTSAgent
from litrl.common.agent import RandomAgent
from litrl.env.make import make
from litrl.env.typing import GymId
from src.app_state import AppState
from src.huggingface.huggingface_client import HuggingFaceClient
from src.typing import BotResponseType, CpuConfig, GridResponseType

# Number of bytes read from disk per chunk when streaming an mp4 file.
_MP4_CHUNK_SIZE = 64 * 1024


def stream_mp4(mp4_path: Path, chunk_size: int = _MP4_CHUNK_SIZE) -> StreamingResponse:
    """Stream the file at ``mp4_path`` as a ``video/mp4`` response.

    Args:
        mp4_path: Path of the mp4 file to send.
        chunk_size: Maximum number of bytes read from disk and sent per chunk.

    Returns:
        A StreamingResponse whose body is read lazily from ``mp4_path``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    def iter_file() -> Generator[bytes, Any, None]:
        # Fixed-size reads: iterating a binary file directly splits on b"\n",
        # which produces arbitrarily sized chunks for video data.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")


def get_app_state() -> AppState:
    """FastAPI dependency providing the AppState used by the Connect Four endpoints.

    NOTE(review): this constructs a new ``AppState()`` on every call, so game
    state only persists between requests if AppState shares its state
    internally (e.g. a singleton) -- confirm against src/app_state.py.
    """
    return AppState()


def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
    """Build the FastAPI app with all routes, the validation error handler, and CORS."""
    app = FastAPI()

    @app.get("/")
    async def redirect_to_docs() -> RedirectResponse:
        """Redirect the root URL to the interactive API docs at /docs."""
        return RedirectResponse("/docs")

    @app.post(path="/connect_four/play", response_model=GridResponseType)
    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Apply ``action`` via ``app_state.step`` and return the resulting grid state."""
        return app_state.step(action)

    @app.get(path="/connect_four/observe", response_model=GridResponseType)
    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Return the grid state reported by ``app_state.observe()``."""
        return app_state.observe()

    @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        """Configure the bot with ``cpu_config``, let it choose and play an action.

        Returns the grid and ``done`` flag after the move, plus the chosen action.
        """
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = app_state.step(action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    @app.get(path="/connect_four/bot_progress", response_model=float)
    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        """Return the fraction of configured MCTS simulations visited so far.

        Non-MCTS agents, and an MCTS agent whose ``mcts`` attribute is None,
        report 1.0.

        Raises:
            ValueError: If the agent is an MCTSAgent and ``cpu_config.simulations``
                is unset, or is not positive when progress must be computed.
        """
        if not isinstance(app_state.agent, MCTSAgent):
            return 1.0
        simulations = app_state.cpu_config.simulations
        if simulations is None:
            msg = "cpu_config.simulations must be set when the agent is an MCTSAgent"
            raise ValueError(msg)
        if app_state.agent.mcts is None:
            return 1.0
        if simulations <= 0:
            # Guard the division below instead of failing with ZeroDivisionError.
            msg = f"cpu_config.simulations must be positive, got {simulations}"
            raise ValueError(msg)
        # Explicit float(): the type checker does not infer this division as float.
        return float(app_state.agent.mcts.root.visits / simulations)

    @app.get(path="/connect_four/reset", response_model=GridResponseType)
    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Reset the game via ``app_state.reset()`` and return the resulting grid state."""
        return app_state.reset()

    @app.get(path="/get_huggingface_video")
    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the mp4 file that ``hf_client.mp4_paths`` maps to ``env_id``.

        Raises:
            HTTPException: 404 when there is no video entry for ``env_id``.
        """
        # NOTE(review): assumes mp4_paths is a mapping that raises KeyError
        # for a missing env_id (e.g. a dict) -- confirm in HuggingFaceClient.
        try:
            mp4_path = hf_client.mp4_paths[env_id]
        except KeyError:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"No video available for env_id {env_id!r}",
            ) from None
        return stream_mp4(mp4_path=mp4_path)

    @app.get(path="/get_env_video")
    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        """Record one episode of a RandomAgent in ``env_id`` and stream the video.

        The environment is reset with seed 123 and the video is written by
        RecordVideo into the ``tmp`` folder.

        Raises:
            ValueError: If RecordVideo has no active video recorder after reset.
        """
        # NOTE(review): every request records into the shared "tmp" folder, so
        # concurrent requests may overwrite each other's file -- confirm this is acceptable.
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)
            if env.video_recorder is None:
                msg = "env.video_recorder is None"
                raise ValueError(msg)
            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)  # type: ignore[arg-type]
                _, _, terminated, truncated, _ = env.step(action=action)
                env.render()
            env.video_recorder.close()
            video_path = Path(env.video_recorder.path)
        finally:
            # Always release the environment, including when recording fails.
            env.close()
        return stream_mp4(mp4_path=video_path)

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(
        request: Request,
        exc: RequestValidationError,
    ) -> JSONResponse:
        """Log the rejected request and return a 422 JSON error describing ``exc``."""
        logger.debug(f"url: {request.url}")
        # _body is a private Starlette attribute, presumably only present once
        # the body has been read.
        if hasattr(request, "_body"):
            # errors="replace" so a non-UTF-8 body cannot crash the error handler.
            logger.debug(f"body: {request._body.decode(errors='replace')}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Collapse newlines and whitespace runs so the message is a single line.
        exc_str = " ".join(str(exc).split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app