Spaces:
Build error
Build error
from typing import Any | |
from litrl import make_multiagent | |
from litrl.algo.mcts.agent import MCTSAgent | |
from litrl.algo.mcts.mcts_config import MCTSConfig | |
from litrl.algo.mcts.rollout import VanillaRollout | |
from litrl.algo.sac.agent import OnnxSacDeterministicAgent | |
from litrl.common.agent import Agent, RandomMultiAgent | |
from litrl.env.connect_four import Board | |
from litrl.env.set_state import set_state | |
from src.typing import AgentType, CpuConfig, RolloutPolicy | |
class AppState: | |
def __init__(self) -> None: | |
self.env = make_multiagent(id="ConnectFour-v3", render_mode="human") | |
self.env.reset(seed=123) | |
self.agent: Agent[Any, Any] | None = None | |
self.cpu_config: CpuConfig | None = None | |
def set_board(self, board: Board) -> None: | |
set_state(env=self.env, board=board) | |
def set_config(self, cpu_config: CpuConfig) -> None: | |
if ( | |
self.agent is None | |
or self.cpu_config is None | |
or cpu_config != self.cpu_config | |
): | |
self.cpu_config = cpu_config | |
self.set_agent() | |
def create_rollout(self) -> Agent[Any, Any]: | |
if self.cpu_config is None: | |
raise ValueError("self.cpu_config is None") | |
match self.cpu_config.rollout_policy: | |
case None: | |
return RandomMultiAgent() | |
case RolloutPolicy.SAC: | |
return OnnxSacDeterministicAgent() | |
case RolloutPolicy.RANDOM: | |
return RandomMultiAgent() | |
case _: | |
raise NotImplementedError( | |
f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}" | |
) | |
def set_agent(self) -> None: | |
if self.cpu_config is None: | |
raise ValueError("self.cpu_config is None") | |
match self.cpu_config.agent_type: | |
case AgentType.MCTS: | |
rollout_agent = self.create_rollout() | |
mcts_config = MCTSConfig( | |
simulations=self.cpu_config.simulations or 50, | |
rollout_strategy=VanillaRollout(rollout_agent=rollout_agent), | |
) | |
self.agent = MCTSAgent(cfg=mcts_config) | |
case AgentType.RANDOM: | |
self.agent = RandomMultiAgent() | |
case AgentType.SAC: | |
self.agent = OnnxSacDeterministicAgent() | |
case _: | |
raise NotImplementedError( | |
f"cpu_config.name: {self.cpu_config.agent_type}" | |
) | |
def get_action(self) -> int: | |
if self.agent is None: | |
raise ValueError("self.agent is None") | |
return self.agent.get_action(env=self.env) | |