Spaces:
Build error
Build error
import sys | |
from pathlib import Path | |
from typing import Any, Generator, List | |
if sys.version_info[:2] >= (3, 11): | |
from typing import Annotated | |
else: | |
from typing_extensions import Annotated | |
from fastapi import Depends, FastAPI, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse | |
from gymnasium.wrappers.record_video import RecordVideo | |
from loguru import logger | |
from pydantic import BaseModel | |
from litrl.algo.mcts.agent import MCTSAgent | |
from litrl.common.agent import RandomAgent | |
from litrl.env.connect_four import Board, ConnectFour | |
from litrl.env.make import make | |
from litrl.env.typing import GymId | |
from src.app_state import AppState | |
from src.huggingface.huggingface_client import HuggingFaceClient | |
from src.typing import CpuConfig | |
def stream_mp4(mp4_path: Path, chunk_size: int = 64 * 1024) -> StreamingResponse:
    """Stream an MP4 file from disk as an HTTP response.

    Args:
        mp4_path: Path of the MP4 file to send.
        chunk_size: Number of bytes read from disk per streamed chunk.
            Must be positive.

    Returns:
        A ``StreamingResponse`` with media type ``video/mp4`` whose body is
        read lazily from ``mp4_path`` while the response is being sent.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    def iter_file() -> Generator[bytes, Any, None]:
        # Read fixed-size chunks. Iterating a binary file object splits on b"\n",
        # which for MP4 data yields arbitrarily sized (often tiny) pieces.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")
# Type of the ``grid`` field in API responses: a list of ``Board`` values.
# ``observe`` fills it from ``obs["observation"].tolist()``.
ObservationType = List[Board]
class GridResponseType(BaseModel):
    """Response body carrying the current board observation and a game-over flag."""

    # Observation for "player_1", converted to nested Python lists (see ``observe``).
    grid: ObservationType
    # True when the currently selected agent is terminated or truncated (see ``observe``).
    done: bool
class BotResponseType(GridResponseType):
    """Grid response extended with the action the bot played."""

    # Action returned by ``AppState.get_action()`` and applied to the env.
    action: int
def get_app_state() -> AppState:
    """FastAPI dependency that supplies the ``AppState`` used by a request.

    NOTE(review): this constructs ``AppState()`` on every call. Game state is
    only shared across requests if ``AppState`` returns a shared instance
    (e.g. a singleton); confirm in ``src.app_state``.
    """
    state = AppState()
    return state
def step(env: ConnectFour, action: int) -> GridResponseType:
    """Play ``action`` in ``env`` and report the resulting position.

    Args:
        env: The Connect Four environment to advance (mutated in place).
        action: Action passed straight to ``env.step``.

    Returns:
        The grid/done snapshot produced by ``observe`` after the move.
    """
    env.step(action)
    snapshot = observe(env)
    return snapshot
def observe(env: ConnectFour) -> GridResponseType:
    """Snapshot the board from player_1's perspective plus a game-over flag.

    Args:
        env: The Connect Four environment to read (not modified).

    Returns:
        ``grid``: the "observation" entry of ``env.observe("player_1")`` as
        nested lists. ``done``: whether the currently selected agent is
        terminated or truncated.
    """
    observation = env.observe("player_1")
    board = observation["observation"].tolist()
    current_agent = env.agent_selection
    finished = env.terminations[current_agent] or env.truncations[current_agent]
    # TODO why needed? -- presumably the terminations/truncations values are
    # not plain Python bools (e.g. numpy bools); bool() normalizes them. Verify.
    return GridResponseType(grid=board, done=bool(finished))
def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
    """Build the FastAPI application for the Connect Four demo.

    Creates the app and defines the request handlers: game moves, bot moves,
    bot progress, reset, and MP4 video streaming. Registers a handler for
    request-validation errors and installs permissive CORS middleware.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    # NOTE(review): the endpoint functions below are defined but never attached
    # to ``app``. No ``@app.get``/``@app.post`` decorators or
    # ``app.add_api_route`` calls are present, so the returned app serves none
    # of them. Their paths and HTTP methods cannot be recovered from this file;
    # restore the route registrations.

    async def redirect_to_docs() -> RedirectResponse:
        """Redirect the client to the interactive API docs at ``/docs``."""
        return RedirectResponse("/docs")

    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Apply a player's ``action`` to the shared env and return the new grid."""
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return response

    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Return the current grid and done flag without changing the game."""
        return observe(app_state.env)

    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        """Configure the bot with ``cpu_config``, let it move, and return the result.

        Returns:
            The updated grid, the done flag, and the action the bot played.
        """
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        """Report bot search progress.

        For an ``MCTSAgent``, this is the MCTS root's visit count divided by
        the configured number of simulations. Any other agent reports 1.0.

        Raises:
            ValueError: If the MCTS agent has no configured simulation count
                or no search tree.
        """
        if isinstance(app_state.agent, MCTSAgent):
            if app_state.cpu_config.simulations is None:
                msg = "cpu_config.simulations is None"
                raise ValueError(msg)
            if app_state.agent.mcts is None:
                msg = "agent.mcts is None"
                raise ValueError(msg)
            return float(
                app_state.agent.mcts.root.visits / app_state.cpu_config.simulations,
            )  # TODO why not recognized as float?
        return 1.0

    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        """Reset the env, notify the app state, and return the fresh grid."""
        app_state.env.reset()
        app_state.inform_reset()
        return observe(app_state.env)

    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the MP4 file that ``hf_client.mp4_paths`` maps to ``env_id``."""
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        """Record one episode of a random agent in ``env_id`` and stream it as MP4.

        The episode is seeded with 123. The video is written under the
        "tmp" folder.

        Raises:
            ValueError: If ``RecordVideo`` did not start a recorder on reset.
        """
        env = make(id=env_id, render_mode="rgb_array")
        # NOTE(review): every call records into the same "tmp" folder; concurrent
        # requests may overwrite each other's files. Verify the naming scheme.
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)
            if env.video_recorder is None:
                msg = "env.video_recorder is None"
                raise ValueError(msg)
            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)  # type: ignore[arg-type]
                _, _, terminated, truncated, _ = env.step(action=action)
                env.render()
            env.video_recorder.close()
            # Capture the path before closing the wrapper, which also tears down
            # the recorder.
            mp4_path = Path(env.video_recorder.path)
        finally:
            # Release the env and the recorder even if recording fails part-way.
            env.close()
        return stream_mp4(mp4_path=mp4_path)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        """Log details of a rejected request and return a flattened 422 response."""
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            # errors="replace": a non-UTF-8 body must not crash the error handler.
            logger.debug(f"body: {request._body.decode(errors='replace')}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Collapse newlines and runs of whitespace into single spaces.
        exc_str = " ".join(str(exc).split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

    # NOTE(review): per the CORS spec, browsers reject credentialed responses
    # whose Access-Control-Allow-Origin is "*". Consider listing explicit origins
    # if credentials are actually needed.
    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app