LitRL-Inference / src /app_state.py
c-gohlke's picture
Upload folder using huggingface_hub
302ae2f verified
raw
history blame
4.5 kB
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from loguru import logger
if TYPE_CHECKING:
import sys
from litrl.env.connect_four import ConnectFour
if sys.version_info[:2] >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy
class AppState:
    """Process-wide singleton holding the Connect Four environment and the CPU opponent.

    The first ``AppState()`` call creates the instance and runs :meth:`setup`;
    every later call returns that same instance.
    """

    # The shared instance; created lazily by ``__new__``.
    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    # The CPU opponent; ``None`` until ``set_agent`` has run.
    agent: Agent[Any, int] | None = None

    def setup(self) -> None:
        """Create the environment (reset with seed 123) and a default 500-simulation MCTS opponent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)
        self.cpu_config = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
        self.set_agent(self.cpu_config)  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating it and running ``setup`` on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch the CPU opponent to ``cpu_config`` if it differs from the current config.

        Args:
            cpu_config: the requested opponent configuration.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config == self.cpu_config:
            logger.info("cpu_config unchanged")
            return
        # ``set_agent`` compares against the *previous* config to decide whether
        # the MCTS agent can be reused, so the new config is stored afterwards.
        self.set_agent(cpu_config)
        self.cpu_config = cpu_config

    def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
        """Build the agent used for MCTS rollouts.

        Returns an ``OnnxSacDeterministicMultiAgent`` for ``RolloutPolicy.SAC``;
        any other policy gets a ``RandomMultiAgent`` seeded with 123.
        """
        if rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
        """Return True if the current MCTS agent can be kept when switching to ``cpu_config``.

        Reuse requires that the current agent is an ``MCTSAgent`` whose ``mcts``
        attribute is set, that the previous config was MCTS, and that the rollout
        policy is unchanged. A different rollout policy needs a new rollout
        strategy, so the agent must be rebuilt in that case.

        Must be called while ``self.cpu_config`` still holds the previous config.
        """
        return (
            isinstance(self.agent, MCTSAgent)
            and self.agent.mcts is not None
            and self.cpu_config.agent_type == AgentType.MCTS
            and self.cpu_config.rollout_policy == cpu_config.rollout_policy
        )

    def set_agent(self, cpu_config: CpuConfig) -> None:
        """Install the CPU agent described by ``cpu_config``.

        For MCTS, when :meth:`can_reuse_mcts_computations` allows it, the existing
        agent is kept and only its simulation budget is updated. Otherwise a fresh
        ``MCTSAgent`` is built. ``self.cpu_config`` must still hold the previous
        configuration when this is called.

        Raises:
            ValueError: if the reuse check passed but the agent is not a usable MCTS agent.
            NotImplementedError: if ``cpu_config.agent_type`` is not supported.
        """
        if cpu_config.agent_type == AgentType.MCTS:
            # Fall back to 50 simulations when the config leaves the budget unset/0.
            simulations = cpu_config.simulations or 50
            if self.can_reuse_mcts_computations(cpu_config):
                if not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                    msg = "expected an MCTSAgent with an initialised mcts to reuse"
                    raise ValueError(msg)
                self.agent.mcts.cfg.simulations = simulations
                logger.debug("set_agent: reusing MCTSAgent")
            else:
                rollout_agent = self.create_rollout(cpu_config.rollout_policy)
                mcts_config = (
                    MCTSConfigBuilder()
                    .set_simulations(simulations)
                    .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
                    .build()
                )
                self.agent = MCTSAgent(cfg=mcts_config)
                logger.debug("set_agent: MCTSAgent")
        elif cpu_config.agent_type == AgentType.RANDOM:
            self.agent = RandomMultiAgent()
        elif cpu_config.agent_type == AgentType.SAC:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment] # TODO
        else:
            msg = f"cpu_config.name: {cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Ask the CPU agent for its action in the current environment state.

        Raises:
            ValueError: if no agent has been set.
        """
        if self.agent is None:
            msg = "no CPU agent has been set"
            raise ValueError(msg)
        return self.agent.get_action(env=self.env)

    def step(self, action: int) -> GridResponseType:
        """Apply ``action`` to the environment and return the new observation.

        An MCTS agent performs the step itself via ``agent.step`` (presumably so it
        can keep its search in sync — confirm); other agents step the env directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.step(self.env, action)
        else:
            self.env.step(action)
        return self.observe()

    def reset(self) -> GridResponseType:
        """Reset the environment (through the MCTS agent when one is active) and return the observation."""
        if isinstance(self.agent, MCTSAgent):
            self.agent.reset(self.env)
        else:
            self.env.reset()
        return self.observe()

    def observe(self) -> GridResponseType:
        """Return ``player_1``'s grid and whether the currently selected agent is terminated/truncated."""
        obs = self.env.observe("player_1")
        return GridResponseType(  # type: ignore[no-any-return]
            grid=obs["observation"].tolist(),
            done=bool(
                self.env.terminations[self.env.agent_selection] or self.env.truncations[self.env.agent_selection],
            ),  # TODO why needed?
        )