Spaces:
Build error
Build error
Upload folder using huggingface_hub
Browse files- src/app.py +4 -4
- src/app_state.py +31 -32
src/app.py
CHANGED
@@ -67,7 +67,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
|
|
67 |
@app.post(path="/connect_four/play", response_model=GridResponseType)
|
68 |
def endpoint_play(
|
69 |
action: int,
|
70 |
-
app_state: Annotated[AppState, Depends(dependency=
|
71 |
) -> GridResponseType:
|
72 |
response = step(app_state.env, action)
|
73 |
app_state.inform_action(action=action)
|
@@ -75,7 +75,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
|
|
75 |
|
76 |
@app.get(path="/connect_four/observe", response_model=GridResponseType)
|
77 |
def endpoint_observe(
|
78 |
-
app_state: Annotated[AppState, Depends(dependency=
|
79 |
) -> GridResponseType:
|
80 |
return observe(app_state.env)
|
81 |
|
@@ -96,7 +96,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
|
|
96 |
|
97 |
@app.get(path="/connect_four/bot_progress", response_model=float)
|
98 |
def endpoint_bot_progress(
|
99 |
-
app_state: Annotated[AppState, Depends(dependency=
|
100 |
) -> float:
|
101 |
if isinstance(app_state.agent, MCTSAgent):
|
102 |
if app_state.cpu_config.simulations is None:
|
@@ -110,7 +110,7 @@ def create_app() -> FastAPI: # noqa: C901 # TODO move to routes
|
|
110 |
|
111 |
@app.get(path="/connect_four/reset", response_model=GridResponseType)
|
112 |
def endpoint_reset(
|
113 |
-
app_state: Annotated[AppState, Depends(dependency=
|
114 |
) -> GridResponseType:
|
115 |
app_state.env.reset()
|
116 |
return observe(app_state.env)
|
|
|
67 |
@app.post(path="/connect_four/play", response_model=GridResponseType)
|
68 |
def endpoint_play(
|
69 |
action: int,
|
70 |
+
app_state: Annotated[AppState, Depends(dependency=get_app_state)],
|
71 |
) -> GridResponseType:
|
72 |
response = step(app_state.env, action)
|
73 |
app_state.inform_action(action=action)
|
|
|
75 |
|
76 |
@app.get(path="/connect_four/observe", response_model=GridResponseType)
|
77 |
def endpoint_observe(
|
78 |
+
app_state: Annotated[AppState, Depends(dependency=get_app_state)],
|
79 |
) -> GridResponseType:
|
80 |
return observe(app_state.env)
|
81 |
|
|
|
96 |
|
97 |
@app.get(path="/connect_four/bot_progress", response_model=float)
|
98 |
def endpoint_bot_progress(
|
99 |
+
app_state: Annotated[AppState, Depends(dependency=get_app_state)],
|
100 |
) -> float:
|
101 |
if isinstance(app_state.agent, MCTSAgent):
|
102 |
if app_state.cpu_config.simulations is None:
|
|
|
110 |
|
111 |
@app.get(path="/connect_four/reset", response_model=GridResponseType)
|
112 |
def endpoint_reset(
|
113 |
+
app_state: Annotated[AppState, Depends(dependency=get_app_state)],
|
114 |
) -> GridResponseType:
|
115 |
app_state.env.reset()
|
116 |
return observe(app_state.env)
|
src/app_state.py
CHANGED
@@ -1,4 +1,11 @@
|
|
1 |
-
from typing import Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
from loguru import logger
|
4 |
|
@@ -42,42 +49,34 @@ class AppState:
|
|
42 |
logger.info("cpu_config unchanged")
|
43 |
|
44 |
def create_rollout(self) -> Agent[Any, Any]:
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
case RolloutPolicy.SAC:
|
49 |
-
return OnnxSacDeterministicMultiAgent()
|
50 |
-
case RolloutPolicy.RANDOM:
|
51 |
-
return RandomMultiAgent()
|
52 |
-
case _:
|
53 |
-
msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
|
54 |
-
raise NotImplementedError(msg)
|
55 |
|
56 |
def set_agent(self) -> None:
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
raise NotImplementedError(msg)
|
76 |
|
77 |
def get_action(self) -> int:
|
78 |
return self.agent.get_action(env=self.env)
|
79 |
|
80 |
def inform_action(self, action: int) -> None:
|
81 |
"""Update the agent's state as a result of external changes to the environment."""
|
82 |
-
if isinstance(self.agent, MCTSAgent)
|
83 |
-
self.agent.
|
|
|
1 |
+
from typing import Any
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
try:
|
6 |
+
from typing import Self
|
7 |
+
except ImportError:
|
8 |
+
from typing_extensions import Self
|
9 |
|
10 |
from loguru import logger
|
11 |
|
|
|
49 |
logger.info("cpu_config unchanged")
|
50 |
|
51 |
def create_rollout(self) -> Agent[Any, Any]:
|
52 |
+
if self.cpu_config.rollout_policy == RolloutPolicy.SAC:
|
53 |
+
return OnnxSacDeterministicMultiAgent()
|
54 |
+
return RandomMultiAgent(np.random.default_rng(seed=123))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def set_agent(self) -> None:
|
57 |
+
if self.cpu_config.agent_type.value == AgentType.MCTS.value:
|
58 |
+
rollout_agent = self.create_rollout()
|
59 |
+
# fmt: off
|
60 |
+
mcts_config = (
|
61 |
+
MCTSConfigBuilder()
|
62 |
+
.set_simulations(self.cpu_config.simulations or 50)
|
63 |
+
.set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
|
64 |
+
).build()
|
65 |
+
# fmt: on
|
66 |
+
self.agent = MCTSAgent(cfg=mcts_config)
|
67 |
+
logger.debug("set_agent: MCTSAgent")
|
68 |
+
elif self.cpu_config.agent_type.value == AgentType.RANDOM.value:
|
69 |
+
self.agent = RandomMultiAgent()
|
70 |
+
elif self.cpu_config.agent_type.value == AgentType.SAC.value:
|
71 |
+
self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
|
72 |
+
else:
|
73 |
+
msg = f"cpu_config.name: {self.cpu_config.agent_type}"
|
74 |
+
raise NotImplementedError(msg)
|
|
|
75 |
|
76 |
def get_action(self) -> int:
|
77 |
return self.agent.get_action(env=self.env)
|
78 |
|
79 |
def inform_action(self, action: int) -> None:
|
80 |
"""Update the agent's state as a result of external changes to the environment."""
|
81 |
+
if isinstance(self.agent, MCTSAgent):
|
82 |
+
self.agent.inform_action(action)
|