from typing import Any, Self from litrl import make_multiagent from litrl.algo.mcts.agent import MCTSAgent from litrl.algo.mcts.mcts_config import MCTSConfig from litrl.algo.mcts.rollout import VanillaRollout from litrl.algo.sac.agent import OnnxSacDeterministicAgent from litrl.common.agent import Agent, RandomMultiAgent from litrl.env.connect_four import Board from litrl.env.set_state import set_state from loguru import logger from src.typing import AgentType, CpuConfig, RolloutPolicy from litrl.env.connect_four import ConnectFour class AppState: _instance: Self | None = None env: ConnectFour cpu_config: CpuConfig agent: Agent[Any, Any] def setup(self) -> None: logger.debug("AppState setup called") self.env = make_multiagent(id="ConnectFour-v3", render_mode="rgb_array") self.env.reset(seed=123) self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM) self.set_agent() # TODO in properties setter. self.agent: Agent[Any, Any] def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) cls._instance.setup() return cls._instance def set_board(self, board: Board) -> None: set_state(env=self.env, board=board) def set_config(self, cpu_config: CpuConfig) -> None: if ( cpu_config != self.cpu_config ): self.cpu_config = cpu_config self.set_agent() def create_rollout(self) -> Agent[Any, Any]: match self.cpu_config.rollout_policy: case None: return RandomMultiAgent() case RolloutPolicy.SAC: return OnnxSacDeterministicAgent() case RolloutPolicy.RANDOM: return RandomMultiAgent() case _: raise NotImplementedError( f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}" ) def set_agent(self) -> None: match self.cpu_config.agent_type: case AgentType.MCTS: rollout_agent = self.create_rollout() mcts_config = MCTSConfig( simulations=self.cpu_config.simulations or 50, rollout_strategy=VanillaRollout(rollout_agent=rollout_agent), ) self.agent = MCTSAgent(cfg=mcts_config) case AgentType.RANDOM: self.agent = RandomMultiAgent() case AgentType.SAC: self.agent = OnnxSacDeterministicAgent() case _: raise NotImplementedError( f"cpu_config.name: {self.cpu_config.agent_type}" ) def get_action(self) -> int: return self.agent.get_action(env=self.env)