c-gohlke commited on
Commit
5cd7fc9
·
verified ·
1 Parent(s): aba3e20

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. src/app.py +4 -4
  2. src/app_state.py +31 -32
src/app.py CHANGED
@@ -67,7 +67,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
67
  @app.post(path="/connect_four/play", response_model=GridResponseType)
68
  def endpoint_play(
69
  action: int,
70
- app_state: Annotated[AppState, Depends(dependency=AppState)],
71
  ) -> GridResponseType:
72
  response = step(app_state.env, action)
73
  app_state.inform_action(action=action)
@@ -75,7 +75,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
75
 
76
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
77
  def endpoint_observe(
78
- app_state: Annotated[AppState, Depends(dependency=AppState)],
79
  ) -> GridResponseType:
80
  return observe(app_state.env)
81
 
@@ -96,7 +96,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
96
 
97
  @app.get(path="/connect_four/bot_progress", response_model=float)
98
  def endpoint_bot_progress(
99
- app_state: Annotated[AppState, Depends(dependency=AppState)],
100
  ) -> float:
101
  if isinstance(app_state.agent, MCTSAgent):
102
  if app_state.cpu_config.simulations is None:
@@ -110,7 +110,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
110
 
111
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
112
  def endpoint_reset(
113
- app_state: Annotated[AppState, Depends(dependency=AppState)],
114
  ) -> GridResponseType:
115
  app_state.env.reset()
116
  return observe(app_state.env)
 
67
  @app.post(path="/connect_four/play", response_model=GridResponseType)
68
  def endpoint_play(
69
  action: int,
70
+ app_state: Annotated[AppState, Depends(dependency=get_app_state)],
71
  ) -> GridResponseType:
72
  response = step(app_state.env, action)
73
  app_state.inform_action(action=action)
 
75
 
76
  @app.get(path="/connect_four/observe", response_model=GridResponseType)
77
  def endpoint_observe(
78
+ app_state: Annotated[AppState, Depends(dependency=get_app_state)],
79
  ) -> GridResponseType:
80
  return observe(app_state.env)
81
 
 
96
 
97
  @app.get(path="/connect_four/bot_progress", response_model=float)
98
  def endpoint_bot_progress(
99
+ app_state: Annotated[AppState, Depends(dependency=get_app_state)],
100
  ) -> float:
101
  if isinstance(app_state.agent, MCTSAgent):
102
  if app_state.cpu_config.simulations is None:
 
110
 
111
  @app.get(path="/connect_four/reset", response_model=GridResponseType)
112
  def endpoint_reset(
113
+ app_state: Annotated[AppState, Depends(dependency=get_app_state)],
114
  ) -> GridResponseType:
115
  app_state.env.reset()
116
  return observe(app_state.env)
src/app_state.py CHANGED
@@ -1,4 +1,11 @@
1
- from typing import Any, Self
 
 
 
 
 
 
 
2
 
3
  from loguru import logger
4
 
@@ -42,42 +49,34 @@ class AppState:
42
  logger.info("cpu_config unchanged")
43
 
44
  def create_rollout(self) -> Agent[Any, Any]:
45
- match self.cpu_config.rollout_policy:
46
- case None:
47
- return RandomMultiAgent()
48
- case RolloutPolicy.SAC:
49
- return OnnxSacDeterministicMultiAgent()
50
- case RolloutPolicy.RANDOM:
51
- return RandomMultiAgent()
52
- case _:
53
- msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
54
- raise NotImplementedError(msg)
55
 
56
  def set_agent(self) -> None:
57
- match self.cpu_config.agent_type.value:
58
- case AgentType.MCTS.value:
59
- rollout_agent = self.create_rollout()
60
- # fmt: off
61
- mcts_config = (
62
- MCTSConfigBuilder()
63
- .set_simulations(self.cpu_config.simulations or 50)
64
- .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
65
- ).build()
66
- # fmt: on
67
- self.agent = MCTSAgent(cfg=mcts_config)
68
- logger.debug("set_agent: MCTSAgent")
69
- case AgentType.RANDOM.value:
70
- self.agent = RandomMultiAgent()
71
- case AgentType.SAC.value:
72
- self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
73
- case _:
74
- msg = f"cpu_config.name: {self.cpu_config.agent_type}"
75
- raise NotImplementedError(msg)
76
 
77
  def get_action(self) -> int:
78
  return self.agent.get_action(env=self.env)
79
 
80
  def inform_action(self, action: int) -> None:
81
  """Update the agent's state as a result of external changes to the environment."""
82
- if isinstance(self.agent, MCTSAgent) and self.agent.mcts is not None:
83
- self.agent.mcts.update_root(action)
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+
5
+ try:
6
+ from typing import Self
7
+ except ImportError:
8
+ from typing_extensions import Self
9
 
10
  from loguru import logger
11
 
 
49
  logger.info("cpu_config unchanged")
50
 
51
  def create_rollout(self) -> Agent[Any, Any]:
52
+ if self.cpu_config.rollout_policy == RolloutPolicy.SAC:
53
+ return OnnxSacDeterministicMultiAgent()
54
+ return RandomMultiAgent(np.random.default_rng(seed=123))
 
 
 
 
 
 
 
55
 
56
  def set_agent(self) -> None:
57
+ if self.cpu_config.agent_type.value == AgentType.MCTS.value:
58
+ rollout_agent = self.create_rollout()
59
+ # fmt: off
60
+ mcts_config = (
61
+ MCTSConfigBuilder()
62
+ .set_simulations(self.cpu_config.simulations or 50)
63
+ .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
64
+ ).build()
65
+ # fmt: on
66
+ self.agent = MCTSAgent(cfg=mcts_config)
67
+ logger.debug("set_agent: MCTSAgent")
68
+ elif self.cpu_config.agent_type.value == AgentType.RANDOM.value:
69
+ self.agent = RandomMultiAgent()
70
+ elif self.cpu_config.agent_type.value == AgentType.SAC.value:
71
+ self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
72
+ else:
73
+ msg = f"cpu_config.name: {self.cpu_config.agent_type}"
74
+ raise NotImplementedError(msg)
 
75
 
76
  def get_action(self) -> int:
77
  return self.agent.get_action(env=self.env)
78
 
79
  def inform_action(self, action: int) -> None:
80
  """Update the agent's state as a result of external changes to the environment."""
81
+ if isinstance(self.agent, MCTSAgent):
82
+ self.agent.inform_action(action)