c-gohlke's picture
Upload folder using huggingface_hub
3414b3b verified
raw
history blame
6.01 kB
import sys
from pathlib import Path
from typing import Any, Generator, List
if sys.version_info[:2] >= (3, 11):
from typing import Annotated
else:
from typing_extensions import Annotated
from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from gymnasium.wrappers.record_video import RecordVideo
from loguru import logger
from pydantic import BaseModel
from litrl.algo.mcts.agent import MCTSAgent
from litrl.common.agent import RandomAgent
from litrl.env.connect_four import Board, ConnectFour
from litrl.env.make import make
from litrl.env.typing import GymId
from src.app_state import AppState
from src.huggingface.huggingface_client import HuggingFaceClient
from src.typing import CpuConfig
def stream_mp4(mp4_path: Path, chunk_size: int = 1024 * 1024) -> StreamingResponse:
    """Stream an MP4 file from disk as an HTTP response.

    The file is opened lazily when the response body is consumed and is sent
    in fixed-size chunks, so memory use stays bounded by ``chunk_size``.

    Args:
        mp4_path: Path of the MP4 file to send.
        chunk_size: Maximum number of bytes read from disk and sent per chunk.

    Returns:
        A ``StreamingResponse`` with media type ``video/mp4``.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    def iter_file() -> Generator[bytes, Any, None]:
        # Iterating a binary file object yields b"\n"-delimited "lines", which
        # for video data are arbitrarily sized; read fixed-size chunks instead.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")
# Type of the ``grid`` field in responses; observe() fills it with
# ``obs["observation"].tolist()``.
# NOTE(review): annotated as a list of litrl ``Board`` values — confirm this
# matches the nested-list structure that ``tolist()`` actually produces.
ObservationType = List[Board]
class GridResponseType(BaseModel):
    """Response body carrying a board observation and an episode-finished flag."""

    grid: ObservationType  # observation converted to plain Python lists (built by observe())
    done: bool  # True when the env's currently selected agent is terminated or truncated
class BotResponseType(GridResponseType):
    """Grid response extended with the action the bot chose and applied."""

    action: int  # action returned by the app state's bot and passed to env.step
def get_app_state() -> AppState:
    """FastAPI dependency that supplies the shared application state.

    Returns:
        The ``AppState`` instance injected into the game endpoints.

    NOTE(review): ``AppState()`` is called on every request, yet the endpoints
    expect game state to persist between requests — presumably ``AppState`` is
    a singleton; confirm.
    """
    state = AppState()
    return state
def step(env: ConnectFour, action: int) -> GridResponseType:
    """Apply ``action`` to ``env`` and report the resulting board.

    Args:
        env: The Connect Four environment to advance.
        action: The action forwarded unchanged to ``env.step``.

    Returns:
        The post-move snapshot produced by ``observe(env)``.
    """
    env.step(action)
    snapshot = observe(env)
    return snapshot
def observe(env: ConnectFour) -> GridResponseType:
    """Snapshot the board as seen by ``player_1`` together with a done flag.

    Args:
        env: The Connect Four environment to inspect.

    Returns:
        A ``GridResponseType`` whose ``grid`` is the ``"observation"`` entry of
        ``env.observe("player_1")`` converted with ``tolist()``, and whose
        ``done`` is set when the currently selected agent is terminated or
        truncated.
    """
    board = env.observe("player_1")["observation"]
    current = env.agent_selection
    finished = env.terminations[current] or env.truncations[current]
    # bool() coerces the stored flag into a plain Python bool for pydantic.
    # NOTE(review): presumably the env stores a non-builtin bool (e.g. NumPy);
    # this answers the original "TODO why needed?" — confirm.
    return GridResponseType(grid=board.tolist(), done=bool(finished))
def create_app() -> FastAPI:  # TODO move to routes
    """Build and configure the FastAPI application.

    Registers the root-to-docs redirect, the Connect Four game endpoints, the
    video endpoints, the request-validation error handler and CORS middleware.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()
    _register_docs_redirect(app)
    _register_connect_four_routes(app)
    _register_video_routes(app)
    _register_validation_handler(app)
    _add_cors_middleware(app)
    return app


def _register_docs_redirect(app: FastAPI) -> None:
    """Redirect the root URL to the interactive OpenAPI docs."""

    @app.get("/")
    async def redirect_to_docs() -> RedirectResponse:
        return RedirectResponse("/docs")


def _register_connect_four_routes(app: FastAPI) -> None:
    """Register the Connect Four play / observe / bot / reset endpoints."""

    # Human move: apply the action, then notify the app state of it.
    @app.post(path="/connect_four/play", response_model=GridResponseType)
    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return response

    # Read-only view of the current board.
    @app.get(path="/connect_four/observe", response_model=GridResponseType)
    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        return observe(app_state.env)

    # Bot move: apply the CPU config, let the bot choose, then play that action.
    @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = step(app_state.env, action)
        app_state.inform_action(action=action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    # Fraction of the MCTS budget already spent (root visits / simulations);
    # non-MCTS agents always report 1.0.
    @app.get(path="/connect_four/bot_progress", response_model=float)
    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        if not isinstance(app_state.agent, MCTSAgent):
            return 1.0
        simulations = app_state.cpu_config.simulations
        if not simulations:
            # None would fail the division below; 0 would divide by zero.
            msg = f"cpu_config.simulations must be a positive number, got {simulations!r}"
            raise ValueError(msg)
        mcts = app_state.agent.mcts
        if mcts is None:
            msg = "MCTS agent has no search tree yet"
            raise ValueError(msg)
        return float(mcts.root.visits / simulations)  # TODO why not recognized as float?

    # Start a new game and return the initial board.
    # NOTE(review): a GET that mutates state; kept as GET for client compatibility.
    @app.get(path="/connect_four/reset", response_model=GridResponseType)
    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        app_state.env.reset()
        app_state.inform_reset()
        return observe(app_state.env)


def _register_video_routes(app: FastAPI) -> None:
    """Register endpoints that stream pre-recorded or freshly recorded env videos."""

    @app.get(path="/get_huggingface_video")
    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        # Single lookup; an unknown env_id still raises KeyError as before.
        mp4_path = hf_client.mp4_paths[env_id]
        return stream_mp4(mp4_path=mp4_path)

    @app.get(path="/get_env_video")
    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        # Record one episode of a random agent into ./tmp and stream the file.
        env = RecordVideo(
            env=make(id=env_id, render_mode="rgb_array"),
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)
            if env.video_recorder is None:
                msg = "env.video_recorder is None"
                raise ValueError(msg)
            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)  # type: ignore[arg-type]
                _, _, terminated, truncated, _ = env.step(action=action)
                # NOTE(review): RecordVideo may already capture a frame on each
                # step, making this render redundant — confirm before removing.
                env.render()
            # Flush the recording to disk before streaming it back.
            env.video_recorder.close()
            mp4_path = Path(env.video_recorder.path)
        finally:
            # Release the wrapped environment even if recording failed.
            env.close()
        return stream_mp4(mp4_path=mp4_path)


def _register_validation_handler(app: FastAPI) -> None:
    """Log request-validation failures and return them as a flat JSON 422 body."""

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            # errors="replace" so a non-UTF-8 body cannot crash the error handler itself.
            logger.debug(f"body: {request._body.decode(errors='replace')}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Collapse newlines and runs of whitespace into single spaces.
        exc_str = " ".join(str(exc).split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)


def _add_cors_middleware(app: FastAPI) -> None:
    """Allow cross-origin requests from any origin, method and header."""
    # NOTE(review): with allow_credentials=True and a wildcard origin, Starlette
    # echoes the caller's Origin for credentialed requests — confirm intended.
    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )