File size: 3,296 Bytes
5b344d4
 
 
5cd7fc9
 
5b344d4
5cd7fc9
5b344d4
 
bafb458
5b344d4
 
 
 
 
 
7011484
bafb458
 
7011484
bafb458
 
7011484
76c534f
7011484
bafb458
 
6e7d45d
 
 
7011484
6e7d45d
 
 
7011484
bafb458
 
6e7d45d
 
acf3b96
6e7d45d
5b344d4
6e7d45d
 
 
 
bafb458
 
acf3b96
7011484
bafb458
 
acf3b96
 
bafb458
 
5cd7fc9
 
 
bafb458
 
5cd7fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bafb458
 
 
acf3b96
5b344d4
 
 
 
acf3b96
 
5cd7fc9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from loguru import logger

if TYPE_CHECKING:
    import sys

    from litrl.env.connect_four import ConnectFour

    if sys.version_info[:2] >= (3, 11):
        from typing import Self
    else:
        from typing_extensions import Self

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfigBuilder
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.model.sac.multi_agent import OnnxSacDeterministicMultiAgent
from src.typing import AgentType, CpuConfig, RolloutPolicy


class AppState:
    """Singleton holding the Connect Four environment and the agent acting in it.

    Every ``AppState()`` call returns the same object. The first call builds
    the environment and a default agent through :meth:`setup`.
    """

    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, int]

    def setup(self) -> None:
        """Create and reset the environment, then install the default random agent."""
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="connect_four", render_mode="rgb_array")
        self.env.reset(seed=123)

        self.cpu_config = CpuConfig(agent_type=AgentType.RANDOM)
        self.set_agent()  # TODO in properties setter.

    def __new__(cls: type[AppState]) -> AppState:  # noqa: PYI034
        """Return the shared instance, creating and setting it up on first use.

        The instance is cached only after ``setup()`` has succeeded, so a
        failed setup is retried on the next call instead of handing out a
        half-initialised object that lacks ``env`` and ``agent``.
        """
        if cls._instance is None:
            instance = super().__new__(cls)
            instance.setup()
            cls._instance = instance
        return cls._instance

    def set_config(self, cpu_config: CpuConfig) -> None:
        """Switch to ``cpu_config`` and rebuild the agent if the config changed.

        Raises:
            NotImplementedError: propagated from :meth:`set_agent` for an
                unsupported agent type. In that case (and for any other
                failure while building the agent) the previous config is
                restored so ``cpu_config`` keeps describing the agent that is
                actually installed.
        """
        logger.info(f"new cpu_config: {cpu_config}")
        if cpu_config == self.cpu_config:
            logger.info("cpu_config unchanged")
            return

        previous_config = self.cpu_config
        self.cpu_config = cpu_config
        try:
            self.set_agent()
        except Exception:
            # Roll back: the old agent is still in place, so the old config
            # must be too, otherwise a retry would be reported as "unchanged".
            self.cpu_config = previous_config
            raise

    def create_rollout(self) -> Agent[Any, Any]:
        """Build the agent used for MCTS rollouts, as selected by ``cpu_config.rollout_policy``.

        Returns the ONNX SAC agent for ``RolloutPolicy.SAC``; any other policy
        yields a random agent driven by a generator seeded with 123.
        """
        if self.cpu_config.rollout_policy == RolloutPolicy.SAC:
            return OnnxSacDeterministicMultiAgent()
        return RandomMultiAgent(np.random.default_rng(seed=123))

    def set_agent(self) -> None:
        """Instantiate ``self.agent`` from the current ``cpu_config``.

        Agent types are matched by comparing enum ``.value`` attributes.

        Raises:
            NotImplementedError: if ``cpu_config.agent_type`` is not one of
                MCTS, RANDOM or SAC.
        """
        agent_type_value = self.cpu_config.agent_type.value
        if agent_type_value == AgentType.MCTS.value:
            rollout_agent = self.create_rollout()
            # fmt: off
            mcts_config = (
                MCTSConfigBuilder()
                # A falsy ``simulations`` (None or 0) falls back to 50.
                .set_simulations(self.cpu_config.simulations or 50)
                .set_rollout_strategy(VanillaRollout(rollout_agent=rollout_agent))
            ).build()
            # fmt: on
            self.agent = MCTSAgent(cfg=mcts_config)
            logger.debug("set_agent: MCTSAgent")
        elif agent_type_value == AgentType.RANDOM.value:
            self.agent = RandomMultiAgent()
        elif agent_type_value == AgentType.SAC.value:
            self.agent = OnnxSacDeterministicMultiAgent()  # type: ignore[assignment]  # TODO
        else:
            msg = f"cpu_config.name: {self.cpu_config.agent_type}"
            raise NotImplementedError(msg)

    def get_action(self) -> int:
        """Return the action the current agent picks for ``self.env``."""
        return self.agent.get_action(env=self.env)

    def inform_reset(self) -> None:
        """Forward a reset notification to the agent; only ``MCTSAgent`` is notified."""
        if isinstance(self.agent, MCTSAgent):
            self.agent.inform_reset()

    def inform_action(self, action: int) -> None:
        """Update the agent's state as a result of external changes to the environment.

        Only ``MCTSAgent`` is notified; other agent types are left untouched.
        """
        if isinstance(self.agent, MCTSAgent):
            self.agent.inform_action(action)