c-gohlke commited on
Commit
7c66cbc
·
1 Parent(s): b90dbee

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. src/app.py +117 -0
src/app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Any, Generator
2
+ from pathlib import Path
3
+ from gymnasium.wrappers.record_video import RecordVideo
4
+ from litrl.env.make import make
5
+ from litrl.common.agent import RandomAgent
6
+ from litrl.env.typing import SingleAgentId
7
+ from fastapi import Depends, FastAPI, Request, status
8
+ from fastapi.exceptions import RequestValidationError
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
+ from litrl.env.connect_four import Board
12
+ from loguru import logger
13
+ from fastapi.responses import StreamingResponse
14
+ from src.app_state import AppState
15
+ from src.typing import CpuConfig
16
+ from src.huggingface.huggingface_client import HuggingFaceClient
17
+
18
+ def stream_mp4(mp4_path: Path) -> StreamingResponse:
19
+ def iter_file()-> Generator[bytes, Any, None]:
20
+ with mp4_path.open(mode="rb") as env_file:
21
+ yield from env_file
22
+
23
+ return StreamingResponse(content=iter_file(), media_type="video/mp4")
24
+
25
+
26
+ def create_app() -> FastAPI:
27
+ app = FastAPI()
28
+
29
+ @app.post("/", response_model=int)
30
+ def bot_action(
31
+ board: Board,
32
+ cpuConfig: CpuConfig,
33
+ app_state: Annotated[AppState, Depends(dependency=AppState)],
34
+ ) -> int:
35
+ app_state.set_config(cpu_config=cpuConfig)
36
+ app_state.set_board(board=board)
37
+ return app_state.get_action()
38
+
39
+ @app.post(path=f"/game", response_model=str)
40
+ def bot_action(
41
+ env_id: SingleAgentId,
42
+ ) -> str:
43
+ env = RecordVideo(
44
+ env=make(id=env_id, render_mode="rgb_array"),
45
+ video_folder="tmp",
46
+ )
47
+ env.reset(seed=123)
48
+ agent = RandomAgent[Any, Any]()
49
+ terminated, truncated = False, False
50
+ while not (terminated or truncated):
51
+ action = agent.get_action(env=env)
52
+ _, _, terminated, truncated, _ = env.step(action=action)
53
+ env.render()
54
+ env.video_recorder.close()
55
+ return stream_mp4(mp4_path=Path(env.video_recorder.path))
56
+
57
+ @app.get(path=f"/hfmp4")
58
+ def fh_stream(
59
+ env_id: SingleAgentId,
60
+ hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
61
+ ) -> StreamingResponse:
62
+ hf_client.mp4_paths[env_id]
63
+ return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])
64
+
65
+ @app.get(path=f"/mp4")
66
+ def bot_action(
67
+ env_id: SingleAgentId,
68
+ ) -> str:
69
+ env = make(id=env_id, render_mode="rgb_array")
70
+ env = RecordVideo(
71
+ env=env,
72
+ video_folder="tmp",
73
+ )
74
+ env.reset(seed=123)
75
+ agent = RandomAgent[Any, Any]()
76
+ terminated, truncated = False, False
77
+ while not (terminated or truncated):
78
+ action = agent.get_action(env=env)
79
+ _, _, terminated, truncated, _ = env.step(action=action)
80
+ env.render()
81
+ env.video_recorder.close()
82
+ return stream_mp4(mp4_path=Path(env.video_recorder.path))
83
+
84
+ @app.exception_handler(exc_class_or_status_code=RequestValidationError)
85
+ async def validation_exception_handler(
86
+ request: Request, exc: RequestValidationError
87
+ ) -> JSONResponse:
88
+ logger.debug(f"url: {request.url}")
89
+ if hasattr(request, "_body"):
90
+ logger.debug(f"body: {request._body}")
91
+ logger.debug(f"header: {request.headers}")
92
+ logger.error(f"{request}: {exc}")
93
+ exc_str = f"{exc}".replace("\n", " ").replace(" ", " ")
94
+ content = {"status_code": 10422, "message": exc_str, "data": None}
95
+ return JSONResponse(
96
+ content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
97
+ )
98
+
99
+ app.add_middleware(
100
+ middleware_class=CORSMiddleware,
101
+ allow_origins="*",
102
+ allow_credentials=True,
103
+ allow_methods=["*"],
104
+ allow_headers=["*"],
105
+ )
106
+ return app
107
+
108
+ if __name__ == "__main__":
109
+ import uvicorn
110
+ import argparse
111
+ parser = argparse.ArgumentParser()
112
+ parser.add_argument("--host", type=str, default="0.0.0.0")
113
+ parser.add_argument("--port", type=int, default=8000)
114
+ args = parser.parse_args()
115
+ config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level="info")
116
+ server = uvicorn.Server(config=config)
117
+ server.run()