File size: 4,499 Bytes
5b344d4
 
 
5cd7fc9
 
5b344d4
5cd7fc9
5b344d4
 
bafb458
5b344d4
 
 
 
 
 
7011484
bafb458
 
7011484
bafb458
 
7011484
302ae2f
7011484
bafb458
 
6e7d45d
 
 
302ae2f
6e7d45d
 
 
7011484
bafb458
 
302ae2f
 
6e7d45d
5b344d4
6e7d45d
 
 
 
bafb458
 
acf3b96
7011484
302ae2f
bafb458
acf3b96
 
bafb458
302ae2f
 
5cd7fc9
 
bafb458
302ae2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cd7fc9
302ae2f
5cd7fc9
 
302ae2f
5cd7fc9
bafb458
 
302ae2f
 
bafb458
acf3b96
302ae2f
5b344d4
302ae2f
 
 
 
5b344d4
302ae2f
5cd7fc9
302ae2f
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from loguru import logger

if TYPE_CHECKING:
    import sys

    from litrl.env.connect_four import ConnectFour

    if sys.version_info[:2] >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, GridResponseType, RolloutPolicy


class AppState:
    """Singleton holding the Connect Four environment and the CPU opponent.

    Every ``AppState()`` call returns the same object; the first call also
    runs :meth:`setup` to create the environment and the default agent.
    """

    # Simulation budget used when a config carries no (or a zero) count.
    _DEFAULT_SIMULATIONS = 50

    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, int] | None = None

    def setup(self) -> None:
        """Create the environment, seed it, and install the default MCTS agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)

        self.cpu_config = CpuConfig(agent_type=AgentType.MCTS, simulations=500)
        self.set_agent(self.cpu_config)  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and setting it up on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config`` if it differs from the active configuration.

        Args:
            cpu_config: The requested CPU opponent configuration.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config == self.cpu_config:
            logger.info("cpu_config unchanged")
            return
        # set_agent compares against the still-active self.cpu_config, so the
        # stored config must only be replaced after the agent has been updated.
        self.set_agent(cpu_config)
        self.cpu_config = cpu_config

    def create_rollout(self, rollout_policy: RolloutPolicy) -> Agent[Any, Any]:
        """Build the agent used to play out MCTS rollouts.

        Args:
            rollout_policy: ``RolloutPolicy.SAC`` selects the ONNX SAC agent;
                any other value selects a seeded random agent.

        Returns:
            A freshly constructed rollout agent.
        """
        if rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def can_reuse_mcts_computations(self, cpu_config: CpuConfig) -> bool:
        """Tell whether the current MCTS agent can be kept for ``cpu_config``.

        The existing agent is reusable only when it is an initialised MCTS
        agent and the rollout policy is unchanged: the reuse path in
        :meth:`set_agent` only adjusts the simulation budget and keeps the
        rollout agent that was built for the active policy.

        Args:
            cpu_config: The configuration about to be applied.

        Returns:
            ``True`` if only the simulation count needs updating.
        """
        return (
            self.agent is not None
            and isinstance(self.agent, MCTSAgent)
            and self.agent.mcts is not None
            and self.cpu_config.agent_type == AgentType.MCTS
            and self.cpu_config.rollout_policy == cpu_config.rollout_policy
        )

    def set_agent(self, cpu_config: CpuConfig) -> None:
        """Install the agent described by ``cpu_config``.

        For MCTS the current agent is kept, with an updated simulation budget,
        when :meth:`can_reuse_mcts_computations` allows it; otherwise a new
        agent is built from the *requested* configuration.

        Args:
            cpu_config: The configuration to apply.

        Raises:
            ValueError: If reuse was approved but no initialised MCTS agent exists.
            NotImplementedError: If ``cpu_config.agent_type`` is not supported.
        """
        if cpu_config.agent_type == AgentType.MCTS:
            # Same fallback on both paths so a missing count never reaches MCTS.
            simulations = cpu_config.simulations or self._DEFAULT_SIMULATIONS
            if self.can_reuse_mcts_computations(cpu_config):
                # Re-checked here to narrow the type of self.agent for checkers.
                if not isinstance(self.agent, MCTSAgent) or self.agent.mcts is None:
                    msg = "MCTS reuse requested without an initialised MCTSAgent"
                    raise ValueError(msg)
                self.agent.mcts.cfg.simulations = simulations
                return
            rollout_agent = self.create_rollout(cpu_config.rollout_policy)
            # fmt: off
            mcts_config = (
                MCTSConfigBuilder()
                .set_simulations(simulations)
                .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
            ).build()
            # fmt: on
            self.agent = MCTSAgent(cfg=mcts_config)
            logger.debug("set_agent: MCTSAgent")
        elif cpu_config.agent_type == AgentType.RANDOM:
            self.agent = RandomMultiAgent()
        elif cpu_config.agent_type == AgentType.SAC:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment]  # TODO
        else:
            msg = f"cpu_config.name: {cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Ask the current agent for its move in the current environment state.

        Raises:
            ValueError: If no agent has been installed.
        """
        if self.agent is None:
            msg = "no agent configured"
            raise ValueError(msg)
        return self.agent.get_action(env=self.env)

    def step(self, action: int) -> GridResponseType:
        """Play ``action`` and return the resulting board.

        An MCTS agent performs the step itself (presumably so it can keep its
        internal search state in sync — confirm against ``MCTSAgent.step``);
        for other agents the environment is stepped directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.step(self.env, action)
        else:
            self.env.step(action)
        return self.observe()

    def reset(self) -> GridResponseType:
        """Reset the game and return the resulting board.

        An MCTS agent performs the reset itself (presumably resetting the
        environment as well — confirm against ``MCTSAgent.reset``); for other
        agents the environment is reset directly.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.reset(self.env)
        else:
            self.env.reset()
        return self.observe()

    def observe(self) -> GridResponseType:
        """Return the board as observed by ``"player_1"`` plus a ``done`` flag.

        ``done`` is true when the currently selected agent is flagged as
        terminated or truncated by the environment.
        """
        obs = self.env.observe("player_1")
        current = self.env.agent_selection
        return GridResponseType(  # type: ignore[no-any-return]
            grid=obs["observation"].tolist(),
            done=bool(self.env.terminations[current] or self.env.truncations[current]),  # TODO why needed?
        )