c-gohlke committed on
Commit
6e7d45d
·
verified ·
1 Parent(s): 0d6b8c7

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. src/app.py +50 -30
  2. src/app_state.py +22 -17
src/app.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Annotated, Any, Generator
2
  from pathlib import Path
3
  from gymnasium.wrappers.record_video import RecordVideo
 
4
  from litrl.env.make import make
5
  from litrl.common.agent import RandomAgent
6
  from litrl.env.typing import SingleAgentId
@@ -13,7 +14,9 @@ from loguru import logger
13
  from fastapi.responses import StreamingResponse, RedirectResponse
14
  from src.app_state import AppState
15
  from src.typing import CpuConfig
 
16
  from src.huggingface.huggingface_client import HuggingFaceClient
 
17
 
18
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
19
  def iter_file()-> Generator[bytes, Any, None]:
@@ -22,6 +25,19 @@ def stream_mp4(mp4_path: Path) -> StreamingResponse:
22
 
23
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def create_app() -> FastAPI:
27
  app = FastAPI()
@@ -30,44 +46,45 @@ def create_app() -> FastAPI:
30
  async def to_docs():
31
  return RedirectResponse("/docs")
32
 
33
- @app.post("/", response_model=int)
34
- def bot_action(
35
- board: Board,
36
- cpuConfig: CpuConfig,
 
 
 
 
 
37
  app_state: Annotated[AppState, Depends(dependency=AppState)],
38
- ) -> int:
39
- app_state.set_config(cpu_config=cpuConfig)
40
- app_state.set_board(board=board)
41
- return app_state.get_action()
 
 
 
 
 
 
 
42
 
43
- @app.post(path=f"/game", response_model=str)
44
- def bot_action(
45
- env_id: SingleAgentId,
46
- ) -> str:
47
- env = RecordVideo(
48
- env=make(id=env_id, render_mode="rgb_array"),
49
- video_folder="tmp",
50
- )
51
- env.reset(seed=123)
52
- agent = RandomAgent[Any, Any]()
53
- terminated, truncated = False, False
54
- while not (terminated or truncated):
55
- action = agent.get_action(env=env)
56
- _, _, terminated, truncated, _ = env.step(action=action)
57
- env.render()
58
- env.video_recorder.close()
59
- return stream_mp4(mp4_path=Path(env.video_recorder.path))
60
 
61
- @app.get(path=f"/hfmp4")
62
- def fh_stream(
63
  env_id: SingleAgentId,
64
  hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
65
  ) -> StreamingResponse:
66
  hf_client.mp4_paths[env_id]
67
  return stream_mp4(mp4_path=hf_client.mp4_paths[env_id])
68
 
69
- @app.get(path=f"/mp4")
70
- def bot_action(
71
  env_id: SingleAgentId,
72
  ) -> StreamingResponse:
73
  env = make(id=env_id, render_mode="rgb_array")
@@ -115,7 +132,10 @@ if __name__ == "__main__":
115
  parser = argparse.ArgumentParser()
116
  parser.add_argument("--host", type=str, default="0.0.0.0")
117
  parser.add_argument("--port", type=int, default=7860)
 
 
118
  args = parser.parse_args()
119
- config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level="info")
 
120
  server = uvicorn.Server(config=config)
121
  server.run()
 
1
  from typing import Annotated, Any, Generator
2
  from pathlib import Path
3
  from gymnasium.wrappers.record_video import RecordVideo
4
+ import numpy as np
5
  from litrl.env.make import make
6
  from litrl.common.agent import RandomAgent
7
  from litrl.env.typing import SingleAgentId
 
14
  from fastapi.responses import StreamingResponse, RedirectResponse
15
  from src.app_state import AppState
16
  from src.typing import CpuConfig
17
+ from pydantic import BaseModel
18
  from src.huggingface.huggingface_client import HuggingFaceClient
19
+ from litrl.env.connect_four import ConnectFour
20
 
21
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
22
  def iter_file()-> Generator[bytes, Any, None]:
 
25
 
26
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
27
 
28
# Nested-list form of a board observation, i.e. what ``.tolist()`` yields for
# the environment's observation array.
ObservationType = list[list[list[int]]]


class GridResponseType(BaseModel):
    """Response payload returned by the ``/connect_four/*`` endpoints."""

    # Board observation serialised as nested lists of ints.
    grid: ObservationType
    # Whether the currently selected agent is terminated or truncated.
    done: bool
32
+
33
def step(env: ConnectFour, action: int) -> GridResponseType:
    """Apply ``action`` to ``env`` and report the resulting board.

    Args:
        env: The Connect Four environment to advance.
        action: The move, passed unchanged to ``env.step``.

    Returns:
        Whatever ``observe`` reports for ``env`` after the move.
    """
    env.step(action)
    board_after_move = observe(env)
    return board_after_move
36
+
37
def observe(env: ConnectFour) -> GridResponseType:
    """Snapshot the board and the game-over flag of ``env``.

    The board is always read through ``env.observe("player_1")``; the ``done``
    flag is taken from whichever agent is currently selected.

    Args:
        env: The Connect Four environment to inspect.

    Returns:
        A ``GridResponseType`` holding the observation as nested lists and
        whether the selected agent is terminated or truncated.
    """
    obs = env.observe("player_1")
    current_agent = env.agent_selection
    done = env.terminations[current_agent] or env.truncations[current_agent]
    # Build the declared response model instead of a bare dict so the return
    # value matches the annotation; ``bool`` normalises any non-builtin truthy
    # flag before validation.
    return GridResponseType(grid=obs["observation"].tolist(), done=bool(done))
41
 
42
  def create_app() -> FastAPI:
43
  app = FastAPI()
 
46
  async def to_docs():
47
  return RedirectResponse("/docs")
48
 
49
# Apply the caller's ``action`` to the environment held by ``AppState`` and
# return the resulting board. (Documented with a comment rather than a
# docstring so the generated OpenAPI description is unchanged.)
@app.post(path="/connect_four/play", response_model=GridResponseType)
def endpoint_play(
    action: int,
    app_state: Annotated[AppState, Depends(dependency=AppState)],
) -> GridResponseType:
    game = app_state.env
    return step(game, action)
55
+
56
# Return the current board of the environment held by ``AppState`` without
# advancing the game. (Comment, not docstring, to keep the OpenAPI output
# unchanged.)
@app.get(path="/connect_four/observe", response_model=GridResponseType)
def endpoint_observe(
    app_state: Annotated[AppState, Depends(dependency=AppState)],
) -> GridResponseType:
    game = app_state.env
    return observe(game)
61
+
62
# Let the configured agent move: install ``cpu_config`` on the app state, ask
# the agent for its action, then play that action on the environment.
# (Comment, not docstring, to keep the OpenAPI output unchanged.)
@app.post(path="/connect_four/bot_play", response_model=GridResponseType)
def endpoint_bot_play(
    cpu_config: CpuConfig,
    app_state: Annotated[AppState, Depends(dependency=AppState)],
) -> GridResponseType:
    app_state.set_config(cpu_config)
    bot_move = app_state.get_action()
    game = app_state.env
    return step(game, bot_move)
70
 
71
# Reset the environment held by ``AppState`` and return the fresh board.
# The reset here is called without a seed.
# NOTE(review): this is a GET route that mutates state; confirm that is
# intended before clients or proxies start caching it.
@app.get(path="/connect_four/reset", response_model=GridResponseType)
def endpoint_reset(
    app_state: Annotated[AppState, Depends(dependency=AppState)],
) -> GridResponseType:
    game = app_state.env
    game.reset()
    return observe(game)
 
 
 
 
 
 
 
 
 
 
 
77
 
78
# Stream the MP4 whose path ``hf_client.mp4_paths`` holds for ``env_id``.
# (Comment, not docstring, to keep the OpenAPI output unchanged.)
@app.get(path="/get_huggingface_video")
def endpoint_get_huggingface_video(
    env_id: SingleAgentId,
    hf_client: Annotated[HuggingFaceClient, Depends(dependency=HuggingFaceClient)],
) -> StreamingResponse:
    # Look the path up exactly once. The previous version also evaluated this
    # same subscript as a bare expression statement and threw the result away.
    video_path = hf_client.mp4_paths[env_id]
    return stream_mp4(mp4_path=video_path)
85
 
86
+ @app.get(path="/get_env_video")
87
+ def endpoint_get_env_video(
88
  env_id: SingleAgentId,
89
  ) -> StreamingResponse:
90
  env = make(id=env_id, render_mode="rgb_array")
 
132
  parser = argparse.ArgumentParser()
133
  parser.add_argument("--host", type=str, default="0.0.0.0")
134
  parser.add_argument("--port", type=int, default=7860)
135
+ parser.add_argument("--log", type=str, default="info")
136
+ parser.add_argument('--reload', action='store_true', help='Reload flag')
137
  args = parser.parse_args()
138
+
139
+ config = uvicorn.Config(app=create_app(), host=args.host, port=args.port, log_level=args.log, reload=args.reload)
140
  server = uvicorn.Server(config=config)
141
  server.run()
src/app_state.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Any
2
 
3
  from litrl import make_multiagent
4
  from litrl.algo.mcts.agent import MCTSAgent
@@ -8,32 +8,42 @@ from litrl.algo.sac.agent import OnnxSacDeterministicAgent
8
  from litrl.common.agent import Agent, RandomMultiAgent
9
  from litrl.env.connect_four import Board
10
  from litrl.env.set_state import set_state
11
-
12
  from src.typing import AgentType, CpuConfig, RolloutPolicy
13
-
14
 
15
  class AppState:
16
- def __init__(self) -> None:
17
- self.env = make_multiagent(id="ConnectFour-v3", render_mode="human")
 
 
 
 
 
 
18
  self.env.reset(seed=123)
19
- self.agent: Agent[Any, Any] | None = None
20
- self.cpu_config: CpuConfig | None = None
21
 
 
 
 
 
 
 
 
 
 
 
22
  def set_board(self, board: Board) -> None:
23
  set_state(env=self.env, board=board)
24
 
25
  def set_config(self, cpu_config: CpuConfig) -> None:
26
  if (
27
- self.agent is None
28
- or self.cpu_config is None
29
- or cpu_config != self.cpu_config
30
  ):
31
  self.cpu_config = cpu_config
32
  self.set_agent()
33
 
34
  def create_rollout(self) -> Agent[Any, Any]:
35
- if self.cpu_config is None:
36
- raise ValueError("self.cpu_config is None")
37
  match self.cpu_config.rollout_policy:
38
  case None:
39
  return RandomMultiAgent()
@@ -47,9 +57,6 @@ class AppState:
47
  )
48
 
49
  def set_agent(self) -> None:
50
- if self.cpu_config is None:
51
- raise ValueError("self.cpu_config is None")
52
-
53
  match self.cpu_config.agent_type:
54
  case AgentType.MCTS:
55
  rollout_agent = self.create_rollout()
@@ -68,6 +75,4 @@ class AppState:
68
  )
69
 
70
  def get_action(self) -> int:
71
- if self.agent is None:
72
- raise ValueError("self.agent is None")
73
  return self.agent.get_action(env=self.env)
 
1
+ from typing import Any, Self
2
 
3
  from litrl import make_multiagent
4
  from litrl.algo.mcts.agent import MCTSAgent
 
8
  from litrl.common.agent import Agent, RandomMultiAgent
9
  from litrl.env.connect_four import Board
10
  from litrl.env.set_state import set_state
11
+ from loguru import logger
12
  from src.typing import AgentType, CpuConfig, RolloutPolicy
13
+ from litrl.env.connect_four import ConnectFour
14
 
15
  class AppState:
16
+ _instance: Self | None = None
17
+ env: ConnectFour
18
+ cpu_config: CpuConfig
19
+ agent: Agent[Any, Any]
20
+
21
def setup(self) -> None:
    """Initialise the environment, default CPU config and agent.

    Creates the ``ConnectFour-v3`` environment in ``rgb_array`` render mode,
    resets it with a fixed seed, installs a ``RANDOM`` agent configuration
    and builds the matching agent through ``set_agent``.
    """
    logger.debug("AppState setup called")
    self.env = make_multiagent(id="ConnectFour-v3", render_mode="rgb_array")
    self.env.reset(seed=123)

    # ``cpu_config`` and ``agent`` are already annotated at class level, so
    # no inline annotations are repeated here; the former trailing bare
    # ``self.agent: ...`` annotation was a no-op statement.
    self.cpu_config = CpuConfig(agent_type=AgentType.RANDOM)
    self.set_agent()  # TODO in properties setter.
29
+
30
def __new__(cls):
    """Return the single shared instance, creating it on first use.

    The first call allocates the instance and runs ``setup`` on it; every
    later call returns that same object.
    """
    if cls._instance is None:
        instance = super().__new__(cls)
        # Publish the singleton only after ``setup`` has succeeded. Caching
        # it first would leave a half-initialised instance behind for every
        # later caller if ``setup`` raised.
        instance.setup()
        cls._instance = instance
    return cls._instance
35
+
36
  def set_board(self, board: Board) -> None:
37
  set_state(env=self.env, board=board)
38
 
39
  def set_config(self, cpu_config: CpuConfig) -> None:
40
  if (
41
+ cpu_config != self.cpu_config
 
 
42
  ):
43
  self.cpu_config = cpu_config
44
  self.set_agent()
45
 
46
  def create_rollout(self) -> Agent[Any, Any]:
 
 
47
  match self.cpu_config.rollout_policy:
48
  case None:
49
  return RandomMultiAgent()
 
57
  )
58
 
59
  def set_agent(self) -> None:
 
 
 
60
  match self.cpu_config.agent_type:
61
  case AgentType.MCTS:
62
  rollout_agent = self.create_rollout()
 
75
  )
76
 
77
  def get_action(self) -> int:
 
 
78
  return self.agent.get_action(env=self.env)