from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from loguru import logger

if TYPE_CHECKING:
    import sys

    from litrl.env.connect_four import ConnectFour

    if sys.version_info[:2] >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy


class AppState:
    """Singleton holding the Connect Four environment and the CPU agent.

    ``AppState()`` always returns the same object; the first call also runs
    :meth:`setup`, which creates the environment and the default MCTS agent.
    """

    # Simulation budget used when a config carries no (or a zero) value.
    _DEFAULT_SIMULATIONS = 50

    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, int] | None = None

    def setup(self) -> None:
        """Create the environment, reset it, and install the default agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)
        self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
        self.set_agent(self.cpu_config)  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and setting it up on first use."""
        if cls._instance is None:
            instance = super().__new__(cls)
            # Publish the singleton only after setup succeeded, so a failing
            # setup() does not leave a half-initialised instance behind.
            instance.setup()
            cls._instance = instance
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config``, rebuilding or updating the agent if it changed."""
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config != self.cpu_config:
            # set_agent compares against the *previous* config, so it must run
            # before self.cpu_config is overwritten.
            self.set_agent(cpu_config)
            self.cpu_config = cpu_config
        else:
            logger.info("cpu_config unchanged")

    def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
        """Return the agent used for MCTS rollouts under ``rollout_policy``.

        ``RolloutPolicy.SAC`` yields the ONNX SAC agent; every other policy
        yields a random agent driven by a fixed-seed generator.
        """
        if rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
        """Tell whether the current MCTS agent can be kept for ``cpu_config``.

        Reuse requires an existing ``MCTSAgent`` with an initialised ``mcts``,
        a current config of type MCTS, and an unchanged rollout policy: a
        different rollout policy needs a freshly built rollout strategy.
        """
        return (
            isinstance(self.agent, MCTSAgent)
            and self.agent.mcts is not None
            and self.cpu_config.agent_type == AgentType.MCTS
            and self.cpu_config.rollout_policy == cpu_config.rollout_policy
        )

    def set_agent(self, cpu_config: CpuConfig) -> None:
        """Install the agent described by ``cpu_config``.

        For MCTS the existing agent is kept when possible and only its
        simulation budget is updated; otherwise a new agent is built.

        Raises:
            ValueError: if reuse was reported possible but the current agent
                is not an initialised ``MCTSAgent``.
            NotImplementedError: for an unknown ``agent_type``.
        """
        if cpu_config.agent_type == AgentType.MCTS:
            # Take the budget from the *requested* config, with one shared
            # fallback for both the build and the reuse path.
            simulations = cpu_config.simulations or self._DEFAULT_SIMULATIONS
            if not self.can_reuse_mcts_computations(cpu_config):
                # fmt: off
                mcts_config = (
                    MCTSConfigBuilder()
                    .set_simulations(simulations)
                    .set_rollout_strategy(VanillaRollout(rollout_agent=self.create_rollout(cpu_config.rollout_policy)))
                ).build()
                # fmt: on
                self.agent = MCTSAgent(cfg=mcts_config)
                logger.debug("set_agent: MCTSAgent")
            else:
                # Re-checked here so type checkers can narrow self.agent.
                if self.agent is None or not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                    msg = "cannot reuse MCTS computations: no initialised MCTSAgent"
                    raise ValueError(msg)
                self.agent.mcts.cfg.simulations = simulations
        elif cpu_config.agent_type == AgentType.RANDOM:
            self.agent = RandomMultiAgent()
        elif cpu_config.agent_type == AgentType.SAC:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment]  # TODO
        else:
            msg = f"cpu_config.name: {cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Ask the current agent for its action in the current environment.

        Raises:
            ValueError: if no agent has been set.
        """
        if self.agent is None:
            msg = "no agent has been set"
            raise ValueError(msg)
        return self.agent.get_action(env=self.env)

    def step(self, action: int) -> GridResponseType:
        """Apply ``action`` and return the resulting grid state."""
        if isinstance(self.agent, MCTSAgent):
            # The MCTS agent is handed the env and action instead of stepping
            # the env directly (presumably to keep its tree in sync - confirm).
            self.agent.step(self.env, action)
        else:
            self.env.step(action)
        return self.observe()

    def reset(self) -> GridResponseType:
        """Reset the game and return the initial grid state."""
        if isinstance(self.agent, MCTSAgent):
            self.agent.reset(self.env)
        else:
            self.env.reset()
        return self.observe()

    def observe(self) -> GridResponseType:
        """Return the grid as observed by ``"player_1"`` plus a ``done`` flag.

        ``done`` is true when the currently selected agent is terminated or
        truncated.
        """
        obs = self.env.observe("player_1")
        return GridResponseType(  # type: ignore[no-any-return]
            grid=obs["observation"].tolist(),
            done=bool(
                self.env.terminations[self.env.agent_selection] or self.env.truncations[self.env.agent_selection],
            ),  # TODO why needed?
        )