Spaces:
Build error
Build error
File size: 4,499 Bytes
5b344d4 5cd7fc9 5b344d4 5cd7fc9 5b344d4 bafb458 5b344d4 7011484 bafb458 7011484 bafb458 7011484 302ae2f 7011484 bafb458 6e7d45d 302ae2f 6e7d45d 7011484 bafb458 302ae2f 6e7d45d 5b344d4 6e7d45d bafb458 acf3b96 7011484 302ae2f bafb458 acf3b96 bafb458 302ae2f 5cd7fc9 bafb458 302ae2f 5cd7fc9 302ae2f 5cd7fc9 302ae2f 5cd7fc9 bafb458 302ae2f bafb458 acf3b96 302ae2f 5b344d4 302ae2f 5b344d4 302ae2f 5cd7fc9 302ae2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from loguru import logger
if TYPE_CHECKING:
import sys
from litrl.env.connect_four import ConnectFour
if sys.version_info[:2] >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy
class AppState:
    """Singleton holding the Connect Four environment and the CPU opponent.

    ``AppState()`` always returns the same object: the first call creates the
    instance and runs :meth:`setup`; later calls return it unchanged.

    Attributes are documented inline below.
    """

    # Simulation budget used when a config leaves ``simulations`` unset/falsy.
    _DEFAULT_SIMULATIONS = 50

    # The single shared instance, created lazily by ``__new__``.
    _instance: Self | None = None
    # Multi-agent Connect Four environment created in ``setup``.
    env: ConnectFour
    # Configuration the current ``agent`` was built from.
    cpu_config: CpuConfig
    # Current CPU opponent; ``None`` until ``set_agent`` has run.
    agent: Agent[Any, int] | None = None

    def setup(self) -> None:
        """Create and seed the environment, then install the default MCTS agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)
        self.cpu_config = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
        self.set_agent(self.cpu_config)  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and initialising it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config``, rebuilding or updating the agent if it changed.

        The stored config is only replaced after ``set_agent`` succeeds, so a
        failing ``set_agent`` leaves the previous config in place.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config != self.cpu_config:
            self.set_agent(cpu_config)
            self.cpu_config = cpu_config
        else:
            logger.info("cpu_config unchanged")

    def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
        """Return the agent used to play out MCTS rollouts.

        ``RolloutPolicy.SAC`` selects the ONNX SAC agent; every other policy
        falls back to a random agent driven by a fixed-seed generator.
        """
        if rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
        """Tell whether the existing MCTS agent can be kept for ``cpu_config``.

        Reuse is possible only when the current agent is an initialised
        ``MCTSAgent``, the current config is MCTS, and the rollout policy is
        unchanged. The reuse path in ``set_agent`` only updates the simulation
        budget, so a different rollout policy must trigger a rebuild.
        """
        return (
            self.agent is not None
            and isinstance(self.agent, MCTSAgent)
            and self.agent.mcts is not None
            and self.cpu_config.agent_type == AgentType.MCTS
            and self.cpu_config.rollout_policy == cpu_config.rollout_policy
        )

    def set_agent(self, cpu_config: CpuConfig) -> None:
        """Install the agent described by ``cpu_config``.

        Raises:
            ValueError: if the MCTS reuse path is taken without an initialised
                ``MCTSAgent`` (defensive; also narrows the type for checkers).
            NotImplementedError: if ``cpu_config.agent_type`` is not handled.
        """
        if cpu_config.agent_type == AgentType.MCTS:
            # Read the budget from the *incoming* config: set_config calls this
            # method before it overwrites self.cpu_config.
            simulations = cpu_config.simulations or self._DEFAULT_SIMULATIONS
            if not self.can_reuse_mcts_computations(cpu_config):
                # fmt: off
                mcts_config = (
                    MCTSConfigBuilder()
                    .set_simulations(simulations)
                    .set_rollout_strategy(VanillaRollout(rollout_agent=self.create_rollout(cpu_config.rollout_policy)))
                ).build()
                # fmt: on
                self.agent = MCTSAgent(cfg=mcts_config)
                logger.debug("set_agent: MCTSAgent")
            else:
                if self.agent is None or not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                    msg = "cannot reuse MCTS computations without an initialised MCTSAgent"
                    raise ValueError(msg)
                # Keep the existing agent and only adjust its budget.
                self.agent.mcts.cfg.simulations = simulations
        elif cpu_config.agent_type == AgentType.RANDOM:
            self.agent = RandomMultiAgent()
        elif cpu_config.agent_type == AgentType.SAC:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
        else:
            msg = f"cpu_config.name: {cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Ask the current agent for its move in the current environment state.

        Raises:
            ValueError: if no agent has been set.
        """
        if self.agent is None:
            msg = "no agent has been set"
            raise ValueError(msg)
        return self.agent.get_action(env=self.env)

    def step(self, action: int) -> GridResponseType:
        """Play ``action`` and return the resulting grid.

        An ``MCTSAgent`` performs the step itself (it is handed the env and
        the action); for any other agent the environment is stepped directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.step(self.env, action)
        else:
            self.env.step(action)
        return self.observe()

    def reset(self) -> GridResponseType:
        """Start a new game and return the initial grid.

        An ``MCTSAgent`` performs the reset itself (it is handed the env); for
        any other agent the environment is reset directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.reset(self.env)
        else:
            self.env.reset()
        return self.observe()

    def observe(self) -> GridResponseType:
        """Return the board as observed by ``"player_1"`` plus a ``done`` flag.

        ``done`` is true when the currently selected agent is marked as
        terminated or truncated by the environment.
        """
        obs = self.env.observe("player_1")
        current = self.env.agent_selection
        return GridResponseType(  # type: ignore[no-any-return]
            grid=obs["observation"].tolist(),
            done=bool(
                self.env.terminations[current] or self.env.truncations[current],
            ),  # TODO why needed?
        )
|