LitRL-Inference / src /app_state.py
c-gohlke's picture
Upload folder using huggingface_hub
7011484 verified
raw
history blame
2.76 kB
from typing import Any, Self
from loguru import logger
from src.typing import AgentType, CpuConfig, RolloutPolicy
from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.env.connect_four import ConnectFour
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
class AppState:
_instance: Self | None = None
env: ConnectFour
cpu_config: CpuConfig
agent: Agent[Any, int]
def setup(self) -> None:
logger.debug("AppState setup called")
self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
self.env.reset(seed=123)
self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
self.set_agent() # TODO in properties setter.
self.agent: Agent[Any, int]
def __new__(cls) -> "AppState":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.setup()
return cls._instance
def set_config(self, cpu_config: CpuConfig) -> None:
if cpu_config != self.cpu_config:
self.cpu_config = cpu_config
self.set_agent()
def create_rollout(self) -> Agent[Any, Any]:
match self.cpu_config.rollout_policy:
case None:
return RandomMultiAgent()
case RolloutPolicy.SAC:
return OnnxSacDeterministicMultiAgent()
case RolloutPolicy.RANDOM:
return RandomMultiAgent()
case _:
msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
raise NotImplementedError(msg)
def set_agent(self) -> None:
match self.cpu_config.agent_type:
case AgentType.MCTS:
rollout_agent = self.create_rollout()
# fmt: off
mcts_config = (
MCTSConfigBuilder()
.set_simulations(self.cpu_config.simulations or 50)
.set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
).build()
# fmt: on
self.agent = MCTSAgent(cfg=mcts_config)
case AgentType.RANDOM:
self.agent = RandomMultiAgent()
case AgentType.SAC:
self.agent = OnnxSacDeterministicMultiAgent() # type: ignore[assignment] # TODO
case _:
msg = f"cpu_config.name: {self.cpu_config.agent_type}"
raise NotImplementedError(msg)
def get_action(self) -> int:
return self.agent.get_action(env=self.env)