from typing import Any, Self from loguru import logger from litrl import make_multiagent from litrl.algo.mcts.agent import MCTSAgent from litrl.algo.mcts.mcts_config import MCTSConfigBuilder from litrl.algo.mcts.rollout import VanillaRollout from litrl.common.agent import Agent, RandomMultiAgent from litrl.env.connect_four import ConnectFour from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent from src.typing import AgentType, CpuConfig, RolloutPolicy class AppState: _instance: Self | None = None env: ConnectFour cpu_config: CpuConfig agent: Agent[Any, int] def setup(self) -> None: logger.debug("AppState setup called") self.env = make_multiagent(id="connect_four", render_mode="rgb_array") self.env.reset(seed=123) self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM) self.set_agent() # TODO in properties setter. self.agent: Agent[Any, Any] def __new__(cls: type["AppState"]) -> "AppState": if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.setup() return cls._instance def set_config(self, cpu_config: CpuConfig) -> None: logger.info(f"new cpu_config: {cpu_config}") if cpu_config != self.cpu_config: self.cpu_config = cpu_config self.set_agent() else: logger.info("cpu_config unchanged") def create_rollout(self) -> Agent[Any, Any]: match self.cpu_config.rollout_policy: case None: return RandomMultiAgent() case RolloutPolicy.SAC: return OnnxSacDeterministicMultiAgent() case RolloutPolicy.RANDOM: return RandomMultiAgent() case _: msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}" raise NotImplementedError(msg) def set_agent(self) -> None: match self.cpu_config.agent_type.value: case AgentType.MCTS.value: rollout_agent = self.create_rollout() # fmt: off mcts_config = ( MCTSConfigBuilder() .set_simulations(self.cpu_config.simulations or 50) .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent)) ).build() # fmt: on self.agent = MCTSAgent(cfg=mcts_config) logger.debug("set_agent: MCTSAgent") case AgentType.RANDOM.value: self.agent = RandomMultiAgent() case AgentType.SAC.value: self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO case _: msg = f"cpu_config.name: {self.cpu_config.agent_type}" raise NotImplementedError(msg) def get_action(self) -> int: return self.agent.get_action(env=self.env) def inform_action(self, action: int) -> None: """Update the agent's state as a result of external changes to the environment.""" if isinstance(self.agent, MCTSAgent) and self.agent.mcts is not None: self.agent.mcts.update_root(action)