Spaces:
Build error
Build error
File size: 2,675 Bytes
bafb458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from typing import Any
from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfig
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.algo.sac.agent import OnnxSacDeterministicAgent
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.env.connect_four import Board
from litrl.env.set_state import set_state
from src.typing import AgentType, CpuConfig, RolloutPolicy
class AppState:
def __init__(self) -> None:
self.env = make_multiagent(id="ConnectFour-v3", render_mode="human")
self.env.reset(seed=123)
self.agent: Agent[Any, Any] | None = None
self.cpu_config: CpuConfig | None = None
def set_board(self, board: Board) -> None:
set_state(env=self.env, board=board)
def set_config(self, cpu_config: CpuConfig) -> None:
if (
self.agent is None
or self.cpu_config is None
or cpu_config != self.cpu_config
):
self.cpu_config = cpu_config
self.set_agent()
def create_rollout(self) -> Agent[Any, Any]:
if self.cpu_config is None:
raise ValueError("self.cpu_config is None")
match self.cpu_config.rollout_policy:
case None:
return RandomMultiAgent()
case RolloutPolicy.SAC:
return OnnxSacDeterministicAgent()
case RolloutPolicy.RANDOM:
return RandomMultiAgent()
case _:
raise NotImplementedError(
f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
)
def set_agent(self) -> None:
if self.cpu_config is None:
raise ValueError("self.cpu_config is None")
match self.cpu_config.agent_type:
case AgentType.MCTS:
rollout_agent = self.create_rollout()
mcts_config = MCTSConfig(
simulations=self.cpu_config.simulations or 50,
rollout_strategy=VanillaRollout(rollout_agent=rollout_agent),
)
self.agent = MCTSAgent(cfg=mcts_config)
case AgentType.RANDOM:
self.agent = RandomMultiAgent()
case AgentType.SAC:
self.agent = OnnxSacDeterministicAgent()
case _:
raise NotImplementedError(
f"cpu_config.name: {self.cpu_config.agent_type}"
)
def get_action(self) -> int:
if self.agent is None:
raise ValueError("self.agent is None")
return self.agent.get_action(env=self.env)
|