File size: 2,810 Bytes
6e7d45d
bafb458
 
 
 
 
 
 
 
 
6e7d45d
bafb458
6e7d45d
bafb458
 
6e7d45d
 
 
 
 
 
 
 
bafb458
 
6e7d45d
 
 
 
 
 
 
 
 
 
bafb458
 
 
 
 
6e7d45d
bafb458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from typing import Any, Self

from litrl import make_multiagent
from litrl.algo.mcts.agent import MCTSAgent
from litrl.algo.mcts.mcts_config import MCTSConfig
from litrl.algo.mcts.rollout import VanillaRollout
from litrl.algo.sac.agent import OnnxSacDeterministicAgent
from litrl.common.agent import Agent, RandomMultiAgent
from litrl.env.connect_four import Board
from litrl.env.set_state import set_state
from loguru import logger
from src.typing import AgentType, CpuConfig, RolloutPolicy
from litrl.env.connect_four import ConnectFour

class AppState:
    _instance: Self | None = None
    env: ConnectFour
    cpu_config: CpuConfig
    agent: Agent[Any, Any]

    def setup(self) -> None:
        logger.debug("AppState setup called")
        self.env = make_multiagent(id="ConnectFour-v3", render_mode="rgb_array")
        self.env.reset(seed=123)

        self.cpu_config: CpuConfig = CpuConfig(agent_type=AgentType.RANDOM)
        self.set_agent()  # TODO in properties setter.
        self.agent: Agent[Any, Any]

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.setup()
        return cls._instance
    
    def set_board(self, board: Board) -> None:
        set_state(env=self.env, board=board)

    def set_config(self, cpu_config: CpuConfig) -> None:
        if (
            cpu_config != self.cpu_config
        ):
            self.cpu_config = cpu_config
            self.set_agent()

    def create_rollout(self) -> Agent[Any, Any]:
        match self.cpu_config.rollout_policy:
            case None:
                return RandomMultiAgent()
            case RolloutPolicy.SAC:
                return OnnxSacDeterministicAgent()
            case RolloutPolicy.RANDOM:
                return RandomMultiAgent()
            case _:
                raise NotImplementedError(
                    f"cpu_config.rollout_policy: {self.cpu_config.rollout_policy}"
                )

    def set_agent(self) -> None:
        match self.cpu_config.agent_type:
            case AgentType.MCTS:
                rollout_agent = self.create_rollout()
                mcts_config = MCTSConfig(
                    simulations=self.cpu_config.simulations or 50,
                    rollout_strategy=VanillaRollout(rollout_agent=rollout_agent),
                )
                self.agent = MCTSAgent(cfg=mcts_config)
            case AgentType.RANDOM:
                self.agent = RandomMultiAgent()
            case AgentType.SAC:
                self.agent = OnnxSacDeterministicAgent()
            case _:
                raise NotImplementedError(
                    f"cpu_config.name: {self.cpu_config.agent_type}"
                )

    def get_action(self) -> int:
        return self.agent.get_action(env=self.env)