File size: 6,407 Bytes
0ceb721
 
d678220
0ceb721
5fc752e
0ceb721
99ac186
3e2038a
1b140f8
6ee82fe
0ceb721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73cd2cf
0ceb721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73cd2cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
080e344
 
 
 
 
 
1663f39
 
e17747a
080e344
 
0ceb721
6ee82fe
080e344
1663f39
 
 
 
 
 
 
0ceb721
 
1663f39
0ceb721
1663f39
0ceb721
080e344
 
 
0ceb721
 
 
 
 
 
 
8ceccef
0ceb721
 
1663f39
0ceb721
8ceccef
0ceb721
1663f39
 
080e344
0ceb721
 
 
1663f39
0ceb721
1663f39
0ceb721
 
 
 
 
1663f39
0ceb721
 
 
 
 
1663f39
 
 
 
0ceb721
 
 
 
 
 
6ee82fe
99ac186
 
 
 
 
e17747a
 
 
 
 
 
0ceb721
8ceccef
0ceb721
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import numpy as np
from tqdm import tqdm
from Shared import Shared
import wandb
from Shared import Shared


class MCAgent(Shared):
    """Tabular on-policy Monte Carlo control agent with an epsilon-greedy policy.

    Attributes set here:
        Q:  (n_states, n_actions) array of action-value estimates.
        R:  R[s][a] is the list of sampled returns used to estimate Q[s, a].
        Pi: (n_states, n_actions) array of action probabilities.

    ``n_states``, ``n_actions``, ``epsilon``, ``gamma``, ``run_name``,
    ``run_episode``, ``test`` and ``save_policy`` are used but not defined in
    this class; presumably they come from ``Shared`` — confirm there.
    """

    def __init__(self, /, **kwargs):
        super().__init__(run_name=self.__class__.__name__, **kwargs)
        self.reset()

    def reset(self):
        """Reset Q, the stored returns R and the policy Pi, then print Pi.

        Pi starts epsilon-greedy around a uniformly random "greedy" action in
        each state: every action gets ``epsilon / n_actions`` and the chosen
        action additionally gets ``1 - epsilon``.
        """
        print("Resetting all state variables...")
        self.Q = np.zeros((self.n_states, self.n_actions))
        self.R = [[[] for _ in range(self.n_actions)] for _ in range(self.n_states)]

        # An arbitrary e-greedy policy
        self.Pi = np.full(
            (self.n_states, self.n_actions), self.epsilon / self.n_actions
        )
        self.Pi[
            np.arange(self.n_states),
            np.random.randint(self.n_actions, size=self.n_states),
        ] = (
            1 - self.epsilon + self.epsilon / self.n_actions
        )
        print("=" * 80)
        print("Initial policy:")
        print(self.Pi)
        print("=" * 80)

    def _update_state_action(self, state, action, G):
        """Record return G for (state, action), refresh Q[state, action] and
        make Pi[state] epsilon-greedy with respect to the updated Q[state]."""
        self.R[state][action].append(G)
        # Q is the sample mean of every return recorded for this pair.
        self.Q[state, action] = np.mean(self.R[state][action])
        # Epsilon-greedy policy improvement: all actions get epsilon / n_actions,
        # the greedy action (np.argmax -> first index on ties) gets the rest.
        self.Pi[state] = np.full(self.n_actions, self.epsilon / self.n_actions)
        self.Pi[state, np.argmax(self.Q[state])] = (
            1 - self.epsilon + self.epsilon / self.n_actions
        )

    def update_first_visit(self, episode_hist):
        """First-visit MC update from one episode.

        Args:
            episode_hist: Sequence of ``(state, action, reward)`` triples in
                time order.

        Only the earliest occurrence of each (state, action) pair in the
        episode contributes its discounted return.
        """
        # Time step of the first occurrence of each pair. Built once, this
        # replaces rescanning the episode prefix at every step (O(T^2) -> O(T)).
        first_visit = {}
        for t, (state, action, _) in enumerate(episode_hist):
            first_visit.setdefault((state, action), t)

        G = 0
        # Walk the episode backwards so G accumulates the discounted return.
        for t in reversed(range(len(episode_hist))):
            state, action, reward = episode_hist[t]
            G = self.gamma * G + reward
            if first_visit[(state, action)] == t:
                self._update_state_action(state, action, G)

    def update_every_visit(self, episode_hist):
        """Every-visit MC update from one episode.

        Args:
            episode_hist: Sequence of ``(state, action, reward)`` triples in
                time order.

        Every occurrence of a (state, action) pair contributes its discounted
        return, including repeat visits within the same episode.
        """
        G = 0
        # Walk the episode backwards so G accumulates the discounted return.
        for t in reversed(range(len(episode_hist))):
            state, action, reward = episode_hist[t]
            G = self.gamma * G + reward
            self._update_state_action(state, action, G)

    def train(
        self,
        n_train_episodes=2000,
        test_every=100,
        update_type="first_visit",
        log_wandb=False,
        save_best=True,
        save_best_dir=None,
        early_stopping=False,
        **kwargs,
    ):
        """Run Monte Carlo control for up to ``n_train_episodes`` episodes.

        Args:
            n_train_episodes: Number of training episodes.
            test_every: Call ``self.test`` whenever ``episode % test_every == 0``
                (episode 0 included).
            update_type: Suffix selecting the update method,
                ``update_<update_type>`` (``"first_visit"`` or ``"every_visit"``).
            log_wandb: Log per-episode stats, plus Q/Pi images at start and at
                each evaluation, to wandb.
            save_best: Call ``self.save_policy`` on every episode where the
                running test success rate exceeds 0.99999.
            save_best_dir: Forwarded to ``self.save_policy``.
            early_stopping: Stop training the first time that threshold is
                exceeded.
            **kwargs: Forwarded to ``self.run_episode`` and ``self.test``.

        Raises:
            AttributeError: If no ``update_<update_type>`` method exists.
        """
        print(f"Training agent for {n_train_episodes} episodes...")
        # Tag the run with the update type. A None run_name is left as None so
        # the "not saving" warning below can actually fire; previously it became
        # the string "None_<update_type>" and that branch was dead code.
        # NOTE(review): the suffix is appended on every call, so repeated
        # train() calls accumulate suffixes in run_name.
        if self.run_name is not None:
            self.run_name = f"{self.run_name}_{update_type}"

        train_running_success_rate = 0.0
        test_success_rate = 0.0
        test_running_success_rate = 0.0
        avg_ep_len = 0.0

        update_func = getattr(self, f"update_{update_type}")

        tqrange = tqdm(range(n_train_episodes))
        tqrange.set_description("Training")

        if log_wandb:
            self.wandb_log_img(episode=None)

        for e in tqrange:
            episode_hist, solved, _ = self.run_episode(**kwargs)
            rewards = [x[2] for x in episode_hist]
            total_reward, avg_reward = sum(rewards), np.mean(rewards)

            # Exponential moving averages with smoothing factor 0.99.
            train_running_success_rate = (
                0.99 * train_running_success_rate + 0.01 * solved
            )
            avg_ep_len = 0.99 * avg_ep_len + 0.01 * len(episode_hist)

            update_func(episode_hist)

            # Test the agent every test_every episodes with the greedy policy (by default)
            if e % test_every == 0:
                test_success_rate = self.test(verbose=False, **kwargs)
                if log_wandb:
                    self.wandb_log_img(episode=e)

            # Between evaluations this averages the most recent test result.
            test_running_success_rate = (
                0.99 * test_running_success_rate + 0.01 * test_success_rate
            )

            stats = {
                "train_running_success_rate": train_running_success_rate,
                "test_running_success_rate": test_running_success_rate,
                "test_success_rate": test_success_rate,
                "avg_ep_len": avg_ep_len,
                "total_reward": total_reward,
                "avg_reward": avg_reward,
            }
            tqrange.set_postfix(stats)

            if log_wandb:
                wandb.log(stats)

            # Treat an (almost) perfect running test success rate as converged.
            if test_running_success_rate > 0.99999:
                if save_best:
                    if self.run_name is None:
                        print("WARNING: run_name is None, not saving best policy.")
                    else:
                        self.save_policy(self.run_name, save_best_dir)

                if early_stopping:
                    print(
                        f"CONVERGED: test success rate running avg reached 100% after {e} episodes."
                    )
                    break

    def wandb_log_img(self, episode=None):
        """Log the current Q-table and policy as wandb images.

        Args:
            episode: Episode index used in the caption, or None for the
                initial (pre-training) snapshot.
        """
        caption_suffix = "Initial" if episode is None else f"After Episode {episode}"
        wandb.log(
            {
                "Q-table": wandb.Image(
                    self.Q,
                    caption=f"Q-table - {caption_suffix}",
                ),
                "Policy": wandb.Image(
                    self.Pi,
                    caption=f"Policy - {caption_suffix}",
                ),
            }
        )