from collections.abc import Generator
from pathlib import Path
from typing import Annotated, Any

from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from gymnasium.wrappers.record_video import RecordVideo
from loguru import logger
from pydantic import BaseModel

from litrl.algo.mcts.agent import MCTSAgent
from litrl.common.agent import RandomAgent
from litrl.env.connect_four import Board, ConnectFour
from litrl.env.make import make
from litrl.env.typing import GymId
from src.app_state import AppState
from src.huggingface.huggingface_client import HuggingFaceClient
from src.typing import CpuConfig


def stream_mp4(mp4_path: Path, chunk_size: int = 64 * 1024) -> StreamingResponse:
    """Return a ``video/mp4`` streaming response for the file at *mp4_path*.

    Args:
        mp4_path: Location of the video file to stream.
        chunk_size: Maximum number of bytes read from disk per yielded chunk.

    Returns:
        A ``StreamingResponse`` whose body is the file content, read lazily.
    """

    def iter_file() -> Generator[bytes, Any, None]:
        with mp4_path.open(mode="rb") as env_file:
            # Iterating a binary file object splits on b"\n", which yields
            # arbitrarily sized pieces for MP4 data; fixed-size reads keep the
            # per-chunk memory bounded while streaming the same bytes.
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")


ObservationType = list[Board]


class GridResponseType(BaseModel):
    """Response body describing the Connect Four grid and whether the game is over."""

    # The "observation" entry of the environment observation, converted with ``tolist()``.
    grid: ObservationType
    # True when the currently selected agent is terminated or truncated.
    done: bool


class BotResponseType(GridResponseType):
    """Grid response that additionally reports the action chosen by the bot."""

    # The action that was applied to the environment.
    action: int


def get_app_state() -> AppState:
    """FastAPI dependency that provides the application state.

    NOTE(review): a new ``AppState()`` is requested on every call; the endpoints
    rely on state persisting between requests, so ``AppState`` is presumably a
    singleton -- confirm in ``src.app_state``.
    """
    return AppState()


def step(env: ConnectFour, action: int) -> GridResponseType:
    """Apply *action* to *env* and return the resulting observation."""
    env.step(action)
    return observe(env)


def observe(env: ConnectFour) -> GridResponseType:
    """Build a ``GridResponseType`` from the ``"player_1"`` observation of *env*."""
    obs = env.observe("player_1")
    current_agent = env.agent_selection
    return GridResponseType(
        grid=obs["observation"].tolist(),
        done=bool(env.terminations[current_agent] or env.truncations[current_agent]),  # TODO why needed?
    )


def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
    """Create the FastAPI application with all routes, handlers and middleware.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    @app.get("/")
    async def redirect_to_docs() -> RedirectResponse:
        """Redirect the root URL to the interactive API docs."""
        return RedirectResponse("/docs")

    @app.post(path="/connect_four/play", response_model=GridResponseType)
    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Play *action* in the shared environment and notify the app state."""
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return response

    @app.get(path="/connect_four/observe", response_model=GridResponseType)
    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Return the current grid without modifying the environment."""
        return observe(app_state.env)

    @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        """Configure the bot with *cpu_config*, let it pick an action and play it."""
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    @app.get(path="/connect_four/bot_progress", response_model=float)
    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        """Return the bot's search progress as ``root.visits / simulations``.

        Agents that are not ``MCTSAgent`` instances always report ``1.0``.

        Raises:
            ValueError: If the agent is an ``MCTSAgent`` but either the
                configured simulation count or the search tree is ``None``.
        """
        if not isinstance(app_state.agent, MCTSAgent):
            return 1.0
        if app_state.cpu_config.simulations is None:
            msg = "cpu_config.simulations is None"
            raise ValueError(msg)
        if app_state.agent.mcts is None:
            msg = "agent.mcts is None"
            raise ValueError(msg)
        return float(
            app_state.agent.mcts.root.visits / app_state.cpu_config.simulations,
        )  # TODO why not recognized as float?

    @app.get(path="/connect_four/reset", response_model=GridResponseType)
    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Reset the shared environment and return the fresh grid."""
        app_state.env.reset()
        return observe(app_state.env)

    @app.get(path="/get_huggingface_video")
    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the video that *hf_client* holds for *env_id*."""
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    @app.get(path="/get_env_video")
    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        """Record one episode of a random agent in *env_id* and stream the video.

        Raises:
            ValueError: If the ``RecordVideo`` wrapper has no video recorder
                after the environment is reset.
        """
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        env.reset(seed=123)
        if env.video_recorder is None:
            msg = "env.video_recorder is None"
            raise ValueError(msg)

        agent = RandomAgent[Any, Any]()
        terminated, truncated = False, False
        while not (terminated or truncated):
            action = agent.get_action(env=env)  # type: ignore[arg-type]
            _, _, terminated, truncated, _ = env.step(action=action)
            env.render()
        # NOTE(review): only the recorder is closed here, not the environment
        # itself -- confirm whether ``env.close()`` should also be called.
        env.video_recorder.close()
        return stream_mp4(mp4_path=Path(env.video_recorder.path))

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        """Log a request that failed validation and answer with a 422 JSON body."""
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            logger.debug(f"body: {request._body.decode()}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Flatten the multi-line validation message: newlines and runs of
        # whitespace are collapsed into single spaces.
        exc_str = " ".join(str(exc).split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

    # NOTE(review): wildcard origins combined with ``allow_credentials=True`` is
    # very permissive -- confirm this is intended outside local development.
    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],  # a sequence of origins, not a bare string
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app