from typing import Any from litrl import make_multiagent from litrl.algo.mcts.agent import MCTSAgent from litrl.algo.mcts.mcts_config import MCTSConfig from litrl.algo.mcts.rollout import VanillaRollout from litrl.algo.sac.agent import OnnxSacDeterministicAgent from litrl.common.agent import Agent, RandomMultiAgent from litrl.env.connect_four import Board from litrl.env.set_state import set_state from src.typing import AgentType, CpuConfig, RolloutPolicy class AppState: def __init__(self) -> None: self.env = make_multiagent(id="ConnectFour-v3", render_mode="human") self.env.reset(seed=123) self.agent: Agent[Any, Any] | None = None self.cpu_config: CpuConfig | None = None def set_board(self, board: Board) -> None: set_state(env=self.env, board=board) def set_config(self, cpu_config: CpuConfig) -> None: if ( self.agent is None or self.cpu_config is None or cpu_config != self.cpu_config ): self.cpu_config = cpu_config self.set_agent() def create_rollout(self) -> Agent[Any, Any]: if self.cpu_config is None: raise ValueError("self.cpu_config is None") match self.cpu_config.rollout_policy: case None: return RandomMultiAgent() case RolloutPolicy.SAC: return OnnxSacDeterministicAgent() case RolloutPolicy.RANDOM: return RandomMultiAgent() case _: raise NotImplementedError( f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}" ) def set_agent(self) -> None: if self.cpu_config is None: raise ValueError("self.cpu_config is None") match self.cpu_config.agent_type: case AgentType.MCTS: rollout_agent = self.create_rollout() mcts_config = MCTSConfig( simulations=self.cpu_config.simulations or 50, rollout_strategy=VanillaRollout(rollout_agent=rollout_agent), ) self.agent = MCTSAgent(cfg=mcts_config) case AgentType.RANDOM: self.agent = RandomMultiAgent() case AgentType.SAC: self.agent = OnnxSacDeterministicAgent() case _: raise NotImplementedError( f"cpu_config.name: {self.cpu_config.agent_type}" ) def get_action(self) -> int: if self.agent is None: raise ValueError("self.agent is None") return self.agent.get_action(env=self.env)