Spaces:
Build error
Build error
from __future__ import annotations | |
from typing import TYPE_CHECKING, Any | |
import numpy as np | |
from loguru import logger | |
if TYPE_CHECKING: | |
import sys | |
from litrl.env.connect_four import ConnectFour | |
if sys.version_info[:2] >= (3, 11): | |
from typing import Self | |
else: | |
from typing_extensions import Self | |
from litrl import make_multiagent | |
from litrl.algo.mcts.agent import MCTSAgent | |
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder | |
from litrl.algo.mcts.rollout import VanillaRollout | |
from litrl.common.agent import Agent, RandomMultiAgent | |
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent | |
from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy | |
class AppState:
    """Singleton holding the Connect Four environment and the CPU opponent.

    ``AppState()`` always returns the same instance. The first call creates it
    and runs :meth:`setup`, which builds the environment and a default MCTS
    agent.

    Attributes are documented inline below.
    """

    # Simulation budget used when a ``CpuConfig`` carries no ``simulations``.
    _DEFAULT_SIMULATIONS = 50

    # The one shared instance; ``None`` until the first successful setup.
    _instance: Self | None = None
    # Environment created by ``make_multiagent(id="connect_four", ...)``.
    env: ConnectFour
    # Configuration the current ``agent`` was built from.
    cpu_config: CpuConfig
    # Agent that picks the CPU's moves; ``None`` until ``set_agent`` has run.
    agent: Agent[Any, int] | None = None

    def setup(self) -> None:
        """Create the environment, seed it, and install the default MCTS agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)
        self.cpu_config = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
        self.set_agent(self.cpu_config)  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and setting it up on first use.

        The instance is stored on the class only after ``setup()`` succeeds,
        so a failed setup is retried on the next call instead of leaving a
        half-initialised singleton behind.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.setup()
            cls._instance = instance
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config``, rebuilding or updating the agent if it changed.

        ``set_agent`` runs before ``self.cpu_config`` is overwritten because
        ``can_reuse_mcts_computations`` compares the new config against the
        one that is still active.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config == self.cpu_config:
            logger.info("cpu_config unchanged")
            return
        self.set_agent(cpu_config)
        self.cpu_config = cpu_config

    def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
        """Return the agent used to play out MCTS rollouts for ``rollout_policy``.

        ``RolloutPolicy.SAC`` selects the ONNX SAC agent; any other value
        falls back to a random agent with a fixed seed.
        """
        if rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
        """Tell whether the current MCTS agent can be kept for ``cpu_config``.

        Reuse requires a live ``MCTSAgent`` (with an ``mcts`` object), an
        active config that is also MCTS, and an *unchanged* rollout policy.
        The rollout agent is fixed when the ``MCTSAgent`` is built, so a new
        policy needs a rebuild; only the simulation budget is updated in
        place.
        """
        current = self.agent
        if current is None or not isinstance(current, MCTSAgent):
            return False
        if current.mcts is None:
            return False
        return (
            self.cpu_config.agent_type == AgentType.MCTS
            and self.cpu_config.rollout_policy == cpu_config.rollout_policy
        )

    def set_agent(self, cpu_config: CpuConfig) -> None:
        """Install the agent described by ``cpu_config``.

        For MCTS the existing agent is kept when
        ``can_reuse_mcts_computations`` allows it and only its simulation
        budget is updated; otherwise a new ``MCTSAgent`` is built. A missing
        ``simulations`` value falls back to ``_DEFAULT_SIMULATIONS`` in both
        paths.

        Raises:
            ValueError: if reuse was reported possible but no usable
                ``MCTSAgent`` is present.
            NotImplementedError: for an unknown ``agent_type``.
        """
        if cpu_config.agent_type == AgentType.MCTS:
            # Read the budget from the *incoming* config: ``self.cpu_config``
            # is still the previous one while ``set_config`` is running.
            simulations = cpu_config.simulations or self._DEFAULT_SIMULATIONS
            if self.can_reuse_mcts_computations(cpu_config):
                # Re-checked here so type checkers can narrow ``self.agent``.
                if self.agent is None or not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                    msg = "MCTS reuse requested without an initialised MCTSAgent"
                    raise ValueError(msg)
                self.agent.mcts.cfg.simulations = simulations
                return
            rollout_agent = self.create_rollout(cpu_config.rollout_policy)
            # fmt: off
            mcts_config = (
                MCTSConfigBuilder()
                .set_simulations(simulations)
                .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
            ).build()
            # fmt: on
            self.agent = MCTSAgent(cfg=mcts_config)
            logger.debug("set_agent: MCTSAgent")
        elif cpu_config.agent_type == AgentType.RANDOM:
            self.agent = RandomMultiAgent()
        elif cpu_config.agent_type == AgentType.SAC:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
        else:
            msg = f"cpu_config.name: {cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Ask the current agent for its move in the current environment state.

        Raises:
            ValueError: if no agent has been set.
        """
        if self.agent is None:
            msg = "no agent has been set"
            raise ValueError(msg)
        return self.agent.get_action(env=self.env)

    def step(self, action: int) -> GridResponseType:
        """Play ``action`` and return the resulting grid.

        An ``MCTSAgent`` is stepped through its own ``step(env, action)``
        (presumably so it can keep its search state in sync; confirm against
        litrl). Other agents have the environment stepped directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.step(self.env, action)
        else:
            self.env.step(action)
        return self.observe()

    def reset(self) -> GridResponseType:
        """Reset the game and return the fresh grid.

        An ``MCTSAgent`` is reset through its own ``reset(env)``; other
        agents have the environment reset directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.reset(self.env)
        else:
            self.env.reset()
        return self.observe()

    def observe(self) -> GridResponseType:
        """Return the board as seen by ``"player_1"`` plus a ``done`` flag.

        ``done`` is true when the currently selected agent is terminated or
        truncated.
        """
        obs = self.env.observe("player_1")
        current_player = self.env.agent_selection
        finished = self.env.terminations[current_player] or self.env.truncations[current_player]
        return GridResponseType(  # type: ignore[no-any-return]
            grid=obs["observation"].tolist(),
            done=bool(finished),  # TODO why needed?
        )