c-gohlke committed on
Commit
302ae2f
·
verified ·
1 Parent(s): 6033379

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. src/app.py +7 -39
  2. src/app_state.py +56 -28
  3. src/typing.py +14 -1
src/app.py CHANGED
@@ -1,6 +1,6 @@
1
  import sys
2
  from pathlib import Path
3
- from typing import Any, Generator, List
4
 
5
  if sys.version_info[:2] >= (3, 11):
6
  from typing import Annotated
@@ -13,16 +13,14 @@ from fastapi.middleware.cors import CORSMiddleware
13
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
14
  from gymnasium.wrappers.record_video import RecordVideo
15
  from loguru import logger
16
- from pydantic import BaseModel
17
 
18
  from litrl.algo.mcts.agent import MCTSAgent
19
  from litrl.common.agent import RandomAgent
20
- from litrl.env.connect_four import Board, ConnectFour
21
  from litrl.env.make import make
22
  from litrl.env.typing import GymId
23
  from src.app_state import AppState
24
  from src.huggingface.huggingface_client import HuggingFaceClient
25
- from src.typing import CpuConfig
26
 
27
 
28
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
@@ -33,35 +31,10 @@ def stream_mp4(mp4_path: Path) -> StreamingResponse:
33
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
34
 
35
 
36
- ObservationType = List[Board]
37
-
38
-
39
- class GridResponseType(BaseModel):
40
- grid: ObservationType
41
- done: bool
42
-
43
-
44
- class BotResponseType(GridResponseType):
45
- action: int
46
-
47
-
48
  def get_app_state() -> AppState:
49
  return AppState()
50
 
51
 
52
- def step(env: ConnectFour, action: int) -> GridResponseType:
53
- env.step(action)
54
- return observe(env)
55
-
56
-
57
- def observe(env: ConnectFour) -> GridResponseType:
58
- obs = env.observe("player_1")
59
- return GridResponseType(
60
- grid=obs["observation"].tolist(),
61
- done=bool(env.terminations[env.agent_selection] or env.truncations[env.agent_selection]), # TODO why needed?
62
- )
63
-
64
-
65
  def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
66
  app = FastAPI()
67
 
@@ -74,15 +47,13 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
74
  action: int,
75
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
76
  ) -> GridResponseType:
77
- response = step(app_state.env, action)
78
- app_state.inform_action(action=action)
79
- return response
80
 
81
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
82
  def endpoint_observe(
83
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
84
  ) -> GridResponseType:
85
- return observe(app_state.env)
86
 
87
  @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
88
  def endpoint_bot_play(
@@ -91,8 +62,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
91
  ) -> BotResponseType:
92
  app_state.set_config(cpu_config)
93
  action = app_state.get_action()
94
- response = step(app_state.env, action)
95
- app_state.inform_action(action=action)
96
  return BotResponseType(
97
  grid=response.grid,
98
  done=response.done,
@@ -107,7 +77,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
107
  if app_state.cpu_config.simulations is None:
108
  raise ValueError
109
  if app_state.agent.mcts is None:
110
- raise ValueError
111
  return float(
112
  app_state.agent.mcts.root.visits / app_state.cpu_config.simulations,
113
  ) # TODO why not recognized as float?
@@ -117,9 +87,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
117
  def endpoint_reset(
118
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
119
  ) -> GridResponseType:
120
- app_state.env.reset()
121
- app_state.inform_reset()
122
- return observe(app_state.env)
123
 
124
  @app.get(path="/get_huggingface_video")
125
  def endpoint_get_huggingface_video(
 
1
  import sys
2
  from pathlib import Path
3
+ from typing import Any, Generator
4
 
5
  if sys.version_info[:2] >= (3, 11):
6
  from typing import Annotated
 
13
  from fastapi.responses import JSONResponse, RedirectResponse, StreamingResponse
14
  from gymnasium.wrappers.record_video import RecordVideo
15
  from loguru import logger
 
16
 
17
  from litrl.algo.mcts.agent import MCTSAgent
18
  from litrl.common.agent import RandomAgent
 
19
  from litrl.env.make import make
20
  from litrl.env.typing import GymId
21
  from src.app_state import AppState
22
  from src.huggingface.huggingface_client import HuggingFaceClient
23
+ from src.typing import BotResponseType, CpuConfig, GridResponseType
24
 
25
 
26
  def stream_mp4(mp4_path: Path) -> StreamingResponse:
 
31
  return StreamingResponse(content=iter_file(), media_type="video/mp4")
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def get_app_state() -> AppState:
35
  return AppState()
36
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
39
  app = FastAPI()
40
 
 
47
  action: int,
48
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
49
  ) -> GridResponseType:
50
+ return app_state.step(action)
 
 
51
 
52
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
53
  def endpoint_observe(
54
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
55
  ) -> GridResponseType:
56
+ return app_state.observe()
57
 
58
  @app.post(path="/connect_four/bot_play", response_model=BotResponseType)
59
  def endpoint_bot_play(
 
62
  ) -> BotResponseType:
63
  app_state.set_config(cpu_config)
64
  action = app_state.get_action()
65
+ response = app_state.step(action)
 
66
  return BotResponseType(
67
  grid=response.grid,
68
  done=response.done,
 
77
  if app_state.cpu_config.simulations is None:
78
  raise ValueError
79
  if app_state.agent.mcts is None:
80
+ return 1.0
81
  return float(
82
  app_state.agent.mcts.root.visits / app_state.cpu_config.simulations,
83
  ) # TODO why not recognized as float?
 
87
  def endpoint_reset(
88
  app_state: Annotated[AppState, Depends(dependency=get_app_state)],
89
  ) -> GridResponseType:
90
+ return app_state.reset()
 
 
91
 
92
  @app.get(path="/get_huggingface_video")
93
  def endpoint_get_huggingface_video(
src/app_state.py CHANGED
@@ -21,23 +21,22 @@ from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
21
  from litrl.algo.mcts.rollout import VanillaRollout
22
  from litrl.common.agent import Agent, RandomMultiAgent
23
  from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
24
- from src.typing import AgentType, CpuConfig, RolloutPolicy
25
 
26
 
27
  class AppState:
28
  _instance: Self | None = None
29
  env: ConnectFour
30
  cpu_config: CpuConfig
31
- agent: Agent[Any, int]
32
 
33
  def setup(self) -> None:
34
  logger.debug("AppState setup called")
35
  self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
36
  self.env.reset(seed=123)
37
 
38
- self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
39
- self.set_agent() # TODO in properties setter.
40
- self.agent: Agent[Any, Any]
41
 
42
  def __new__(cls: type[AppState]) -> AppState: # noqa: PYI034
43
  if cls._instance is None:
@@ -48,44 +47,73 @@ class AppState:
48
  def set_config(self, cpu_config: CpuConfig) -> None:
49
  logger.info(f"new cpu_config: {cpu_config}")
50
  if cpu_config != self.cpu_config:
 
51
  self.cpu_config = cpu_config
52
- self.set_agent()
53
  else:
54
  logger.info("cpu_config unchanged")
55
 
56
- def create_rollout(self) -> Agent[Any, Any]:
57
- if self.cpu_config.rollout_policy == RolloutPolicy.SAC:
58
  return OnnxSacDeterministicMultiAgent()
59
  return RandomMultiAgent(np.random.default_rng(seed=123))
60
 
61
- def set_agent(self) -> None:
62
- if self.cpu_config.agent_type.value == AgentType.MCTS.value:
63
- rollout_agent = self.create_rollout()
64
- # fmt: off
65
- mcts_config = (
66
- MCTSConfigBuilder()
67
- .set_simulations(self.cpu_config.simulations or 50)
68
- .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
69
- ).build()
70
- # fmt: on
71
- self.agent = MCTSAgent(cfg=mcts_config)
72
- logger.debug("set_agent: MCTSAgent")
73
- elif self.cpu_config.agent_type.value == AgentType.RANDOM.value:
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  self.agent = RandomMultiAgent()
75
- elif self.cpu_config.agent_type.value == AgentType.SAC.value:
76
  self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
77
  else:
78
- msg = f"cpu_config.name: {self.cpu_config.agent_type}"
79
  raise NotImplementedError(msg)
80
 
81
  def get_action(self) -> int:
 
 
82
  return self.agent.get_action(env=self.env)
83
 
84
- def inform_reset(self) -> None:
85
  if isinstance(self.agent, MCTSAgent):
86
- self.agent.inform_reset()
 
 
 
87
 
88
- def inform_action(self, action: int) -> None:
89
- """Update the agent's state as a result of external changes to the environment."""
90
  if isinstance(self.agent, MCTSAgent):
91
- self.agent.inform_action(action)
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from litrl.algo.mcts.rollout import VanillaRollout
22
  from litrl.common.agent import Agent, RandomMultiAgent
23
  from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
24
+ from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy
25
 
26
 
27
  class AppState:
28
  _instance: Self | None = None
29
  env: ConnectFour
30
  cpu_config: CpuConfig
31
+ agent: Agent[Any, int] | None = None
32
 
33
  def setup(self) -> None:
34
  logger.debug("AppState setup called")
35
  self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
36
  self.env.reset(seed=123)
37
 
38
+ self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
39
+ self.set_agent(self.cpu_config) # TODO in properties setter.
 
40
 
41
  def __new__(cls: type[AppState]) -> AppState: # noqa: PYI034
42
  if cls._instance is None:
 
47
  def set_config(self, cpu_config: CpuConfig) -> None:
48
  logger.info(f"new cpu_config: {cpu_config}")
49
  if cpu_config != self.cpu_config:
50
+ self.set_agent(cpu_config)
51
  self.cpu_config = cpu_config
 
52
  else:
53
  logger.info("cpu_config unchanged")
54
 
55
def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
    """Build the agent used to play out MCTS rollouts.

    Args:
        rollout_policy: The requested rollout policy.

    Returns:
        An ``OnnxSacDeterministicMultiAgent`` for ``RolloutPolicy.SAC``;
        for any other value, a ``RandomMultiAgent`` driven by a generator
        with a fixed seed.
    """
    if rollout_policy != RolloutPolicy.SAC:
        return RandomMultiAgent(np.random.default_rng(seed=123))
    return OnnxSacDeterministicMultiAgent()
59
 
60
def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
    """Return whether the current MCTS agent can be kept for ``cpu_config``.

    ``set_agent`` only updates the simulation budget of the existing agent
    when this returns True, and builds a new agent (with a new rollout
    strategy) otherwise. Reuse is therefore only valid when the requested
    rollout policy is the one the current agent was built with.

    Args:
        cpu_config: The requested configuration, compared against the
            configuration currently stored on ``self``.

    Returns:
        True when the current agent is an ``MCTSAgent`` with an ``mcts``
        object, the stored configuration is an MCTS configuration, and the
        rollout policy is unchanged.
    """
    return (
        # ``isinstance`` is False for ``None``, so no separate None check is needed.
        isinstance(self.agent, MCTSAgent)
        and self.agent.mcts is not None
        and self.cpu_config.agent_type == AgentType.MCTS
        # Must be equal: a different rollout policy requires a new agent.
        and self.cpu_config.rollout_policy == cpu_config.rollout_policy
    )
68
+
69
def set_agent(self, cpu_config: CpuConfig) -> None:
    """Install the agent described by ``cpu_config`` on ``self.agent``.

    For an MCTS configuration the existing agent is kept, and only its
    simulation budget updated, when ``can_reuse_mcts_computations`` allows
    it. Otherwise a new ``MCTSAgent`` is built. A missing simulation count
    falls back to 50 in both cases.

    Args:
        cpu_config: The requested agent configuration.

    Raises:
        ValueError: If reuse was reported possible but the current agent is
            not an ``MCTSAgent`` with an ``mcts`` object.
        NotImplementedError: If ``cpu_config.agent_type`` is not handled.
    """
    agent_type = cpu_config.agent_type
    if agent_type == AgentType.MCTS:
        # Read the budget from the requested config: ``set_config`` calls
        # this method before storing ``cpu_config`` on ``self``, so
        # ``self.cpu_config`` may still hold the previous settings.
        simulations = cpu_config.simulations or 50
        if self.can_reuse_mcts_computations(cpu_config):
            # Re-checked here so type checkers can narrow ``self.agent``.
            if not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                msg = "cannot reuse MCTS computations: current agent is not an MCTSAgent with an mcts object"
                raise ValueError(msg)
            self.agent.mcts.cfg.simulations = simulations
            return
        rollout_agent = self.create_rollout(cpu_config.rollout_policy)
        # fmt: off
        mcts_config = (
            MCTSConfigBuilder()
            .set_simulations(simulations)
            .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
        ).build()
        # fmt: on
        self.agent = MCTSAgent(cfg=mcts_config)
        logger.debug("set_agent: MCTSAgent")
    elif agent_type == AgentType.RANDOM:
        self.agent = RandomMultiAgent()
    elif agent_type == AgentType.SAC:
        self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
    else:
        msg = f"cpu_config.name: {cpu_config.agent_type}"
        raise NotImplementedError(msg)
92
 
93
  def get_action(self) -> int:
94
+ if self.agent is None:
95
+ raise ValueError
96
  return self.agent.get_action(env=self.env)
97
 
98
def step(self, action: int) -> GridResponseType:
    """Play ``action`` and return the resulting observation.

    Args:
        action: The action to apply to the environment.

    Returns:
        The result of ``self.observe()`` after the move.
    """
    agent = self.agent
    if not isinstance(agent, MCTSAgent):
        self.env.step(action)
    else:
        # An MCTS agent is handed the environment and applies the move
        # itself. Presumably this keeps its search state in sync; confirm
        # against MCTSAgent.step.
        agent.step(self.env, action)
    return self.observe()
104
 
105
def reset(self) -> GridResponseType:
    """Start a new game and return the resulting observation.

    Returns:
        The result of ``self.observe()`` after the reset.
    """
    agent = self.agent
    if not isinstance(agent, MCTSAgent):
        self.env.reset()
    else:
        # An MCTS agent is handed the environment and performs the reset
        # itself. Presumably this also clears its search state; confirm
        # against MCTSAgent.reset.
        agent.reset(self.env)
    return self.observe()
111
+
112
def observe(self) -> GridResponseType:
    """Return the current board and whether the game is over.

    Returns:
        A ``GridResponseType`` whose ``grid`` is the ``"observation"`` entry
        of the environment's observation for ``"player_1"``, converted with
        ``tolist()``. ``done`` is True when the environment marks the
        currently selected agent as terminated or truncated.
    """
    obs = self.env.observe("player_1")
    current_agent = self.env.agent_selection
    return GridResponseType(
        grid=obs["observation"].tolist(),
        done=bool(
            self.env.terminations[current_agent] or self.env.truncations[current_agent],
        ),  # TODO why needed?
    )
src/typing.py CHANGED
@@ -1,10 +1,14 @@
1
  from __future__ import annotations
2
 
3
  import enum
4
- from typing import Optional
5
 
6
  from pydantic import BaseModel
7
 
 
 
 
 
8
 
9
  class AgentType(enum.Enum):
10
  RANDOM = "random"
@@ -21,3 +25,12 @@ class CpuConfig(BaseModel):
21
  agent_type: AgentType
22
  simulations: Optional[int] = None # noqa: UP007
23
  rollout_policy: Optional[RolloutPolicy] = None # noqa: UP007
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  import enum
4
+ from typing import List, Optional
5
 
6
  from pydantic import BaseModel
7
 
8
+ from litrl.env.connect_four import Board
9
+
10
+ ObservationType = List[Board]
11
+
12
 
13
  class AgentType(enum.Enum):
14
  RANDOM = "random"
 
25
  agent_type: AgentType
26
  simulations: Optional[int] = None # noqa: UP007
27
  rollout_policy: Optional[RolloutPolicy] = None # noqa: UP007
28
+
29
+
30
class GridResponseType(BaseModel):
    """Response model carrying the board and the game-over flag."""

    # Board observation, typed as ``ObservationType`` (``List[Board]``).
    grid: ObservationType
    # Whether the game has finished.
    done: bool
33
+
34
+
35
class BotResponseType(GridResponseType):
    """Response model for a bot move: the grid response plus the chosen action."""

    # Action chosen by the bot.
    action: int