File size: 2,762 Bytes
6e7d45d
bafb458
7011484
 
 
bafb458
 
7011484
bafb458
 
6e7d45d
7011484
 
bafb458
 
6e7d45d
 
 
7011484
6e7d45d
 
 
7011484
bafb458
 
6e7d45d
 
7011484
6e7d45d
7011484
6e7d45d
 
 
 
bafb458
 
7011484
bafb458
 
 
 
 
 
 
 
7011484
bafb458
 
 
7011484
 
bafb458
 
 
 
 
7011484
 
 
 
 
 
 
bafb458
 
 
 
7011484
bafb458
7011484
 
bafb458
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Any, Self

from loguru import logger
from src.typing import AgentType, CpuConfig, RolloutPolicy

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.env.connect_four import ConnectFour
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent


class AppState:
    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, int]

    def setup(self) -> None:
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)

        self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
        self.set_agent()  # TODO in properties setter.
        self.agent: Agent[Any, int]

    def __new__(cls) -> "AppState":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        if cpu_config != self.cpu_config:
            self.cpu_config = cpu_config
            self.set_agent()

    def create_rollout(self) -> Agent[Any, Any]:
        match self.cpu_config.rollout_policy:
            case None:
                return RandomMultiAgent()
            case RolloutPolicy.SAC:
                return OnnxSacDeterministicMultiAgent()
            case RolloutPolicy.RANDOM:
                return RandomMultiAgent()
            case _:
                msg = f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
                raise NotImplementedError(msg)

    def set_agent(self) -> None:
        match self.cpu_config.agent_type:
            case AgentType.MCTS:
                rollout_agent = self.create_rollout()
                # fmt: off
                mcts_config = (
                    MCTSConfigBuilder()
                    .set_simulations(self.cpu_config.simulations or 50)
                    .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
                ).build()
                # fmt: on
                self.agent = MCTSAgent(cfg=mcts_config)
            case AgentType.RANDOM:
                self.agent = RandomMultiAgent()
            case AgentType.SAC:
                self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment]  # TODO
            case _:
                msg = f"cpu_config.name: {self.cpu_config.agent_type}"
                raise NotImplementedError(msg)

    def get_action(self) -> int:
        return self.agent.get_action(env=self.env)