Spaces:
Build error
Build error
from typing import Annotated, Any, Generator | |
from pathlib import Path | |
from gymnasium.wrappers.record_video import RecordVideo | |
import numpy as np | |
from litrl.env.make import make | |
from litrl.common.agent import RandomAgent | |
from litrl.env.typing import SingleAgentId | |
from fastapi import Depends, FastAPI, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from litrl.env.connect_four import Board | |
from loguru import logger | |
from fastapi.responses import StreamingResponse, RedirectResponse | |
from src.app_state import AppState | |
from src.typing import CpuConfig | |
from pydantic import BaseModel | |
from src.huggingface.huggingface_client import HuggingFaceClient | |
from litrl.env.connect_four import ConnectFour | |
def stream_mp4(mp4_path: Path, chunk_size: int = 64 * 1024) -> StreamingResponse:
    """Stream an MP4 file from disk as an HTTP response.

    Args:
        mp4_path: Path of the MP4 file to send. The file is opened lazily,
            when the response body starts being consumed.
        chunk_size: Number of bytes read from disk per chunk (default 64 KiB).

    Returns:
        A ``StreamingResponse`` with media type ``video/mp4`` whose body is the
        file's bytes.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # read(0) returns b"" immediately, which would silently yield an empty body.
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    def iter_file() -> Generator[bytes, Any, None]:
        # Binary data has no meaningful "lines": iterating the file object would
        # split on arbitrary b"\n" bytes and yield chunks of unpredictable size.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")
# Nested-list (JSON-friendly) form of a board observation: three levels of
# lists with int leaves — presumably the result of ``ndarray.tolist()`` on the
# env's observation array; confirm against the env's observation space.
ObservationType = list[list[list[int]]]


class GridResponseType(BaseModel):
    """Response body for the Connect Four endpoints: the board grid and a game-over flag."""

    # Board observation in nested-list form.
    grid: ObservationType
    # True when the currently selected agent is terminated or truncated.
    done: bool
def step(env: ConnectFour, action: int) -> GridResponseType:
    """Advance ``env`` by one move and report the resulting board.

    Args:
        env: The Connect Four environment; it is mutated in place.
        action: Action forwarded unchanged to ``env.step``.

    Returns:
        The post-move state as produced by ``observe(env)``.
    """
    env.step(action)
    post_move_state = observe(env)
    return post_move_state
def observe(env: ConnectFour) -> GridResponseType:
    """Build the API response describing the current state of ``env``.

    The grid is always taken from ``"player_1"``'s observation, regardless of
    whose turn it is.

    Args:
        env: The Connect Four environment to inspect (not modified).

    Returns:
        A ``GridResponseType`` holding the observation grid in nested-list form
        and whether the currently selected agent is terminated or truncated.
    """
    obs = env.observe("player_1")
    current_agent = env.agent_selection
    # NOTE(review): bool() guards against non-builtin truthy flags
    # (e.g. numpy.bool_) — confirm what ConnectFour stores in these dicts.
    done = bool(env.terminations[current_agent] or env.truncations[current_agent])
    # Return the declared model rather than a bare dict so the annotation holds.
    return GridResponseType(grid=obs["observation"].tolist(), done=done)
def create_app() -> FastAPI:
    """Build the FastAPI application.

    The app gets permissive CORS middleware and a handler that turns request
    validation errors into a logged, single-line 422 JSON payload.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI()

    # NOTE(review): none of the endpoint functions below is attached to ``app``
    # in this file as shown (no route decorators, no ``app.add_api_route``
    # calls), so they are not served. Confirm the route registrations were not
    # lost; their paths and HTTP methods cannot be recovered from this code.

    async def to_docs() -> RedirectResponse:
        """Redirect to the interactive API docs at ``/docs``."""
        return RedirectResponse("/docs")

    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> GridResponseType:
        """Play ``action`` in the app-state env and return the new board."""
        return step(app_state.env, action)

    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> GridResponseType:
        """Return the current board without changing it."""
        return observe(app_state.env)

    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> GridResponseType:
        """Apply ``cpu_config``, let the bot pick an action, play it, and return the board."""
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        return step(app_state.env, action)

    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=AppState)],
    ) -> GridResponseType:
        """Reset the env and return the starting board."""
        app_state.env.reset()
        return observe(app_state.env)

    def endpoint_get_huggingface_video(
        env_id: SingleAgentId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        """Stream the MP4 that ``hf_client`` maps to ``env_id``.

        An ``env_id`` missing from ``hf_client.mp4_paths`` raises ``KeyError``.
        """
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    def endpoint_get_env_video(
        env_id: SingleAgentId,
    ) -> StreamingResponse:
        """Record one random-agent episode of ``env_id`` (seed 123) and stream it as MP4."""
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)
            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)
                _, _, terminated, truncated, _ = env.step(action=action)
                env.render()
            env.video_recorder.close()
            # Capture the path before closing the env so it is still readable.
            mp4_path = Path(env.video_recorder.path)
        finally:
            # Release the env (and its recorder) even if the episode fails.
            env.close()
        return stream_mp4(mp4_path=mp4_path)

    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        """Log an invalid request and answer with a 422 JSON error.

        The response body has the shape
        ``{"status_code": 10422, "message": <error text>, "data": None}``.
        """
        logger.debug(f"url: {request.url}")
        # ``_body`` is private; presumably it only exists once the body has
        # been read, hence the hasattr probe.
        if hasattr(request, "_body"):
            logger.debug(f"body: {request._body}")
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        # Flatten the multi-line error text to one line and collapse runs of
        # whitespace so the message reads cleanly inside the JSON payload.
        exc_str = " ".join(f"{exc}".split())
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(
            content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )

    app.add_exception_handler(RequestValidationError, validation_exception_handler)

    # allow_origins must be a sequence of origins; a bare "*" string only
    # matched because `"*" in "*"` is a substring test.
    # NOTE(review): wildcard origins together with allow_credentials=True lets
    # any site make credentialed requests — confirm this is intended.
    app.add_middleware(
        middleware_class=CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn using CLI-provided settings.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--log", type=str, default="info")
    parser.add_argument("--reload", action="store_true", help="Reload flag")
    args = parser.parse_args()

    if args.reload:
        # Auto-reload is only implemented by ``uvicorn.run`` when the app is
        # given as an import string; ``uvicorn.Server.run`` with an app
        # instance never reloads, so say so instead of silently ignoring it.
        logger.warning(
            "--reload is ignored: auto-reload requires passing the app to "
            "uvicorn as an import string."
        )

    config = uvicorn.Config(
        app=create_app(), host=args.host, port=args.port, log_level=args.log
    )
    server = uvicorn.Server(config=config)
    server.run()