Spaces:
Build error
Build error
File size: 3,296 Bytes
5b344d4 5cd7fc9 5b344d4 5cd7fc9 5b344d4 bafb458 5b344d4 7011484 bafb458 7011484 bafb458 7011484 76c534f 7011484 bafb458 6e7d45d 7011484 6e7d45d 7011484 bafb458 6e7d45d acf3b96 6e7d45d 5b344d4 6e7d45d bafb458 acf3b96 7011484 bafb458 acf3b96 bafb458 5cd7fc9 bafb458 5cd7fc9 bafb458 acf3b96 5b344d4 acf3b96 5cd7fc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from loguru import logger
if TYPE_CHECKING:
import sys
from litrl.env.connect_four import ConnectFour
if sys.version_info[:2] >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, RolloutPolicy
class AppState:
    """Singleton holding the Connect Four environment, the CPU config and its agent.

    Every call to ``AppState()`` returns the same instance. The first
    successful call runs :meth:`setup`, which creates the environment, installs
    the default ``CpuConfig`` and builds the agent that matches it.
    """

    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, int]

    def setup(self) -> None:
        """Create the environment, the default config and the matching agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)
        self.cpu_config = CpuConfig(agent_type=AgentType.RANDOM)
        self.set_agent()  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and initialising it on first use.

        The instance is cached only after :meth:`setup` has completed, so a
        failing setup propagates its exception and leaves nothing cached; the
        next call starts again from scratch instead of handing out a
        half-initialised object.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.setup()
            cls._instance = instance
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config`` and rebuild the agent if the config changed.

        Args:
            cpu_config: The configuration to apply.

        Raises:
            NotImplementedError: Propagated from :meth:`set_agent` for an
                unsupported agent type. The previous configuration is restored
                before the exception is re-raised.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config == self.cpu_config:
            logger.info("cpu_config unchanged")
            return

        previous_config = self.cpu_config
        self.cpu_config = cpu_config
        try:
            self.set_agent()
        except Exception:
            # Keep config and agent in sync: without the rollback a retry with
            # the same config would be reported as "unchanged" and the agent
            # would never be rebuilt.
            self.cpu_config = previous_config
            raise

    def create_rollout(self) -> Agent[Any, Any]:
        """Build the rollout agent selected by ``cpu_config.rollout_policy``.

        Returns:
            An ``OnnxSacDeterministicMultiAgent`` when the policy is
            ``RolloutPolicy.SAC``; otherwise a ``RandomMultiAgent`` driven by a
            NumPy generator seeded with 123.
        """
        if self.cpu_config.rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def set_agent(self) -> None:
        """Instantiate ``self.agent`` according to ``cpu_config.agent_type``.

        Raises:
            NotImplementedError: If the agent type is not MCTS, RANDOM or SAC.
        """
        # NOTE(review): members are compared through ``.value`` rather than
        # directly; presumably this keeps the match working if the enum class
        # is re-imported -- confirm before simplifying.
        agent_type_value = self.cpu_config.agent_type.value
        if agent_type_value == AgentType.MCTS.value:
            rollout_agent = self.create_rollout()
            # fmt: off
            mcts_config = (
                MCTSConfigBuilder()
                # A falsy ``simulations`` (None or 0) falls back to 50.
                .set_simulations(self.cpu_config.simulations or 50)
                .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
            ).build()
            # fmt: on
            self.agent = MCTSAgent(cfg=mcts_config)
            logger.debug("set_agent: MCTSAgent")
        elif agent_type_value == AgentType.RANDOM.value:
            self.agent = RandomMultiAgent()
        elif agent_type_value == AgentType.SAC.value:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
        else:
            msg = f"cpu_config.name: {self.cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Return the action the current agent picks for ``self.env``."""
        return self.agent.get_action(env=self.env)

    def inform_reset(self) -> None:
        """Forward an environment reset to the agent; only ``MCTSAgent`` is notified."""
        if isinstance(self.agent, MCTSAgent):
            self.agent.inform_reset()

    def inform_action(self, action: int) -> None:
        """Update the agent's state as a result of external changes to the environment.

        Only an ``MCTSAgent`` is notified; any other agent is left untouched.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.inform_action(action)
|