File size: 5,183 Bytes
5b344d4
7c66cbc
302ae2f
5b344d4
 
 
 
 
7011484
7c66cbc
 
 
7011484
 
7c66cbc
7011484
 
 
 
 
76c534f
 
302ae2f
7011484
7c66cbc
 
7011484
7c66cbc
 
 
 
 
7011484
acf3b96
 
 
 
7011484
7c66cbc
 
7011484
 
dc7242a
 
6e7d45d
 
 
5cd7fc9
6e7d45d
302ae2f
7011484
6e7d45d
 
5cd7fc9
6e7d45d
302ae2f
7011484
acf3b96
6e7d45d
 
acf3b96
 
6e7d45d
 
302ae2f
acf3b96
 
 
 
 
7c66cbc
7011484
 
5cd7fc9
7011484
 
 
 
acf3b96
302ae2f
7011484
acf3b96
 
7011484
 
6e7d45d
 
5cd7fc9
6e7d45d
302ae2f
7c66cbc
6e7d45d
 
7011484
7c66cbc
 
 
 
 
6e7d45d
 
7011484
dc7242a
7c66cbc
 
 
 
 
 
7011484
 
 
 
 
7c66cbc
 
 
7011484
7c66cbc
 
 
 
 
 
7011484
7c66cbc
 
7011484
7c66cbc
 
7011484
7c66cbc
7011484
7c66cbc
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import sys
from pathlib import Path
from typing import Any, Generator

if sys.version_info[:2] >= (3, 11):
    from typing import Annotated
else:
    from typing_extensions import Annotated

from fastapi import Depends, FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
from gymnasium.wrappers.record_video import RecordVideo
from loguru import logger

from litrl.algo.mcts.agent import MCTSAgent
from litrl.common.agent import RandomAgent
from litrl.env.make import make
from litrl.env.typing import GymId
from src.app_state import AppState
from src.huggingface.huggingface_client import HuggingFaceClient
from src.typing import BotResponseType, CpuConfig, GridResponseType


def stream_mp4(mp4_path: Path, chunk_size: int = 1024 * 1024) -> StreamingResponse:
    """Stream an MP4 file from disk as an HTTP response.

    The file is opened lazily, when the response body is first iterated, and
    read in fixed-size binary chunks. Iterating a binary file object directly
    would split on ``b"\\n"`` bytes. For video data that yields arbitrarily
    sized (sometimes tiny, sometimes huge) pieces.

    Args:
        mp4_path: Path of the ``.mp4`` file to send.
        chunk_size: Maximum number of bytes per streamed chunk (default 1 MiB).

    Returns:
        A ``StreamingResponse`` with media type ``video/mp4``.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size <= 0:
        msg = f"chunk_size must be positive, got {chunk_size}"
        raise ValueError(msg)

    def iter_file() -> Generator[bytes, Any, None]:
        # The file handle lives only as long as the generator is being consumed.
        with mp4_path.open(mode="rb") as env_file:
            while chunk := env_file.read(chunk_size):
                yield chunk

    return StreamingResponse(content=iter_file(), media_type="video/mp4")


def get_app_state() -> AppState:
    """FastAPI dependency that hands an ``AppState`` to the Connect Four endpoints.

    ``AppState()`` is called on every request that depends on this function.
    NOTE(review): whether game state survives between requests depends on how
    ``AppState`` is implemented (e.g. whether it is a shared singleton); that
    is not visible here. Confirm before relying on it.
    """
    state = AppState()
    return state


def create_app() -> FastAPI:  # noqa: C901 # TODO move to routes
    """Build the FastAPI application and register all of its routes.

    Registered routes:

    * ``GET /``: redirects to the interactive docs at ``/docs``.
    * ``/connect_four/*``: play, observe, bot moves, bot progress and reset for
      a Connect Four game held in the ``AppState`` provided by ``get_app_state``.
    * ``GET /get_huggingface_video``: streams an MP4 known to ``HuggingFaceClient``.
    * ``GET /get_env_video``: records one episode of a random agent and streams it.

    A handler for ``RequestValidationError`` logs the offending request and
    returns a 422 JSON body. Permissive CORS middleware is also installed.

    Returns:
        The configured ``FastAPI`` application.
    """
    app = FastAPI()

    @app.get("/")
    async def redirect_to_docs() -> RedirectResponse:
        # The root URL has no content of its own; send visitors to the OpenAPI UI.
        return RedirectResponse("/docs")

    @app.post(path="/connect_four/play", response_model=GridResponseType)
    def endpoint_play(
        action: int,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        # Apply a human player's move and return the resulting grid.
        return app_state.step(action)

    @app.get(path="/connect_four/observe", response_model=GridResponseType)
    def endpoint_observe(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        # Return the current grid without changing it.
        return app_state.observe()

    @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
    def endpoint_bot_play(
        cpu_config: CpuConfig,
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> BotResponseType:
        # Configure the bot, let it choose a move, apply the move, and report
        # both the new grid and the action it chose.
        app_state.set_config(cpu_config)
        action = app_state.get_action()
        response = app_state.step(action)
        return BotResponseType(
            grid=response.grid,
            done=response.done,
            action=action,
        )

    @app.get(path="/connect_four/bot_progress", response_model=float)
    def endpoint_bot_progress(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> float:
        # Fraction of the configured MCTS simulations done at the search root.
        # Non-MCTS agents, and an MCTS agent without a search tree, report 1.0.
        if not isinstance(app_state.agent, MCTSAgent):
            return 1.0
        simulations = app_state.cpu_config.simulations
        if simulations is None or simulations <= 0:
            # Reject a missing or non-positive budget instead of dividing by it.
            msg = f"cpu_config.simulations must be a positive number, got {simulations!r}"
            raise ValueError(msg)
        if app_state.agent.mcts is None:
            return 1.0
        # True division already yields a float at runtime; float() is kept for the
        # type checker (presumably root.visits is untyped/Any: TODO confirm).
        return float(app_state.agent.mcts.root.visits / simulations)

    @app.get(path="/connect_four/reset", response_model=GridResponseType)
    def endpoint_reset(
        app_state: Annotated[AppState, Depends(dependency=get_app_state)],
    ) -> GridResponseType:
        # Start a new game and return the initial grid.
        return app_state.reset()

    @app.get(path="/get_huggingface_video")
    def endpoint_get_huggingface_video(
        env_id: GymId,
        hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
    ) -> StreamingResponse:
        # An env_id missing from mp4_paths raises KeyError here.
        # NOTE(review): that surfaces as a 500; a 404 may be what clients expect.
        return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])

    @app.get(path="/get_env_video")
    def endpoint_get_env_video(
        env_id: GymId,
    ) -> StreamingResponse:
        # Record one seeded episode of a random agent into ./tmp and stream the MP4.
        env = make(id=env_id, render_mode="rgb_array")
        env = RecordVideo(
            env=env,
            video_folder="tmp",
        )
        try:
            env.reset(seed=123)

            if env.video_recorder is None:
                msg = "env.video_recorder is None"
                raise ValueError(msg)

            agent = RandomAgent[Any, Any]()
            terminated, truncated = False, False
            while not (terminated or truncated):
                action = agent.get_action(env=env)  # type: ignore[arg-type]
                _, _, terminated, truncated, _ = env.step(action=action)
                # NOTE(review): RecordVideo presumably captures frames on its own;
                # this extra render() may be redundant. Confirm before removing.
                env.render()
            # Finalize the file on disk before streaming it.
            env.video_recorder.close()
            video_path = Path(env.video_recorder.path)
        finally:
            # Release the environment even when recording fails part-way.
            env.close()
        return stream_mp4(mp4_path=video_path)

    @app.exception_handler(exc_class_or_status_code=RequestValidationError)
    async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
        # Log as much of the rejected request as is available, then return a
        # compact single-line error message with HTTP 422.
        logger.debug(f"url: {request.url}")
        if hasattr(request, "_body"):
            # errors="replace" keeps a non-UTF-8 body from crashing this handler.
            logger.debug(f"body: {request._body.decode(errors='replace')}")  # noqa: SLF001
        logger.debug(f"header: {request.headers}")
        logger.error(f"{request}: {exc}")
        exc_str = str(exc).replace("\n", " ").replace("   ", " ")
        content = {"status_code": 10422, "message": exc_str, "data": None}
        return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)

    app.add_middleware(
        middleware_class=CORSMiddleware,
        # A sequence of origins, matching the list form of the other options.
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app