Andrei Cozma committed
Commit: 0ceb721 · Parent(s): 01901c5

Updates
Files changed:
- .gitignore (+2 -0)
- MonteCarloAgent.py (+307 -0)
- mc/mc_test.py (+0 -43)
- mc/mc_train.py (+0 -109)
- mc/policy.npy (+0 -0)
.gitignore CHANGED
@@ -178,3 +178,5 @@ pyrightconfig.json
 .DS_Store
 .idea
 .vscode
+
+wandb
MonteCarloAgent.py ADDED
@@ -0,0 +1,307 @@
import numpy as np
import gymnasium as gym
from tqdm import tqdm
import argparse

import wandb


class MonteCarloAgent:
    def __init__(self, env_name="CliffWalking-v0", gamma=0.99, epsilon=0.1, **kwargs):
        print(f"# MonteCarloAgent - {env_name}")
        print(f"- epsilon: {epsilon}")
        print(f"- gamma: {gamma}")
        self.env = gym.make(env_name, **kwargs)
        self.epsilon, self.gamma = epsilon, gamma
        self.n_states, self.n_actions = (
            self.env.observation_space.n,
            self.env.action_space.n,
        )
        print(f"- n_states: {self.n_states}")
        print(f"- n_actions: {self.n_actions}")
        self.reset()

    def reset(self):
        print("Resetting all state variables...")
        self.Q = np.zeros((self.n_states, self.n_actions))
        self.R = [[[] for _ in range(self.n_actions)] for _ in range(self.n_states)]

        # An arbitrary e-greedy policy
        self.Pi = np.full(
            (self.n_states, self.n_actions), self.epsilon / self.n_actions
        )
        self.Pi[
            np.arange(self.n_states),
            np.random.randint(self.n_actions, size=self.n_states),
        ] = (
            1 - self.epsilon + self.epsilon / self.n_actions
        )
        print("=" * 80)
        print("Initial policy:")
        print(self.Pi)
        print("=" * 80)

    def choose_action(self, state):
        # Sample an action from the policy
        return np.random.choice(self.n_actions, p=self.Pi[state])

    def run_episode(self, max_steps=500, **kwargs):
        state, _ = self.env.reset()
        episode_hist = []
        finished = False
        # Generate an episode following the current policy
        for _ in range(max_steps):
            # Sample an action from the policy
            action = self.choose_action(state)
            # Take the action and observe the reward and next state
            next_state, reward, finished, _, _ = self.env.step(action)
            # Keeping track of the trajectory
            episode_hist.append((state, action, reward))
            state = next_state
            # This is where the agent got to the goal.
            # In the case in which the agent jumped off the cliff, it is simply respawned at the start position without termination.
            if finished:
                break

        return episode_hist, finished

    def update(self, episode_hist):
        G = 0
        # For each step of the episode, in reverse order
        for t in range(len(episode_hist) - 1, -1, -1):
            state, action, reward = episode_hist[t]
            # Update the expected return
            G = self.gamma * G + reward
            # If we haven't already visited this state-action pair up to this point, then we can update the Q-table and policy
            # This is the first-visit MC method
            if (state, action) not in [(x[0], x[1]) for x in episode_hist[:t]]:
                self.R[state][action].append(G)
                self.Q[state, action] = np.mean(self.R[state][action])
                # Epsilon-greedy policy update
                self.Pi[state] = np.full(self.n_actions, self.epsilon / self.n_actions)
                # the greedy action is the one with the highest Q-value
                self.Pi[state, np.argmax(self.Q[state])] = (
                    1 - self.epsilon + self.epsilon / self.n_actions
                )

    def train(self, n_train_episodes=2500, test_every=100, log_wandb=False, **kwargs):
        print(f"Training agent for {n_train_episodes} episodes...")
        train_running_success_rate, test_success_rate = 0.0, 0.0
        stats = {
            "train_running_success_rate": train_running_success_rate,
            "test_success_rate": test_success_rate,
        }
        tqrange = tqdm(range(n_train_episodes))
        tqrange.set_description("Training")

        if log_wandb:
            self.wandb_log_img(episode=None)

        for e in tqrange:
            episode_hist, finished = self.run_episode(**kwargs)
            rewards = [x[2] for x in episode_hist]
            total_reward, avg_reward = sum(rewards), np.mean(rewards)
            train_running_success_rate = (
                0.99 * train_running_success_rate + 0.01 * finished
            )
            self.update(episode_hist)

            stats = {
                "train_running_success_rate": train_running_success_rate,
                "test_success_rate": test_success_rate,
                "total_reward": total_reward,
                "avg_reward": avg_reward,
            }
            tqrange.set_postfix(stats)

            if e % test_every == 0:
                test_success_rate = self.test(verbose=False, **kwargs)

                if log_wandb:
                    self.wandb_log_img(episode=e)

            stats["test_success_rate"] = test_success_rate
            tqrange.set_postfix(stats)

            if log_wandb:
                wandb.log(stats)

    def test(self, n_test_episodes=50, verbose=True, **kwargs):
        if verbose:
            print(f"Testing agent for {n_test_episodes} episodes...")
        num_successes = 0
        for e in range(n_test_episodes):
            _, finished = self.run_episode(**kwargs)
            num_successes += finished
            if verbose:
                word = "reached" if finished else "did not reach"
                emoji = "🏁" if finished else "🚫"
                print(
                    f"({e + 1:>{len(str(n_test_episodes))}}/{n_test_episodes}) - Agent {word} the goal {emoji}"
                )

        success_rate = num_successes / n_test_episodes
        if verbose:
            print(
                f"Agent reached the goal in {num_successes}/{n_test_episodes} episodes ({success_rate * 100:.2f}%)"
            )
        return success_rate

    def save_policy(self, fname="policy.npy"):
        print(f"Saving policy to {fname}")
        np.save(fname, self.Pi)

    def load_policy(self, fname="policy.npy"):
        print(f"Loading policy from {fname}")
        self.Pi = np.load(fname)

    def wandb_log_img(self, episode=None, mask=None):
        caption_suffix = "Initial" if episode is None else f"After Episode {episode}"
        wandb.log(
            {
                "Q-table": wandb.Image(
                    self.Q,
                    caption=f"Q-table - {caption_suffix}",
                ),
                "Policy": wandb.Image(
                    self.Pi,
                    caption=f"Policy - {caption_suffix}",
                ),
            }
        )


def main():
    parser = argparse.ArgumentParser()

    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent. (default: False)",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2000,
        help="The number of episodes to train for.",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=250,
        help="The number of episodes to test for.",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=250,
        help="During training, test the agent every n episodes.",
    )

    parser.add_argument(
        "--max_steps",
        type=int,
        default=500,
        help="The maximum number of steps per episode before the episode is forced to end.",
    )

    ### Agent parameters
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="The value for the discount factor to use.",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.1,
        help="The value for the epsilon-greedy policy to use.",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        help="The Gymnasium environment to use.",
    )
    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="The render mode to use. By default, no rendering is done. To render the environment, set this to 'human'.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done.",
    )
    parser.add_argument(
        "--wandb_group",
        type=str,
        default="monte-carlo",
        help="WandB group name for logging. (default: monte-carlo)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )

    args = parser.parse_args()

    mca = MonteCarloAgent(
        args.env,
        gamma=args.gamma,
        epsilon=args.epsilon,
        render_mode=args.render_mode,
    )

    run_name = f"mc_{args.env}_e{args.n_train_episodes}_s{args.max_steps}_g{args.gamma}_e{args.epsilon}"

    try:
        if args.train:
            # Log to WandB
            if args.wandb_project is not None:
                wandb.init(
                    project=args.wandb_project,
                    name=run_name,
                    group=args.wandb_group,
                    job_type=args.wandb_job_type,
                    config=dict(args._get_kwargs()),
                )

            mca.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                log_wandb=args.wandb_project is not None,
            )
            mca.save_policy(fname=f"policy_{run_name}.npy")
        elif args.test is not None:
            if not args.test.endswith(".npy"):
                args.test += ".npy"
            mca.load_policy(args.test)
            mca.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
        else:
            print("ERROR: Please provide either --train or --test.")
    except KeyboardInterrupt:
        print("Exiting...")


if __name__ == "__main__":
    main()
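For context, here is a minimal sketch of driving the new agent programmatically rather than through the CLI in main(). It mirrors the --train path above and assumes MonteCarloAgent.py is importable from the working directory; the episode counts and the policy_example.npy filename are illustrative and not part of this commit.

# Hypothetical usage sketch (not part of the commit): train, evaluate, and save a policy
# using the defaults defined in the class above, with wandb logging disabled.
from MonteCarloAgent import MonteCarloAgent

agent = MonteCarloAgent("CliffWalking-v0", gamma=0.99, epsilon=0.1)
agent.train(
    n_train_episodes=2000,
    test_every=250,
    n_test_episodes=10,
    max_steps=500,
    log_wandb=False,
)
success_rate = agent.test(n_test_episodes=10, max_steps=500)
print(f"Final success rate: {success_rate * 100:.1f}%")
agent.save_policy("policy_example.npy")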
mc/mc_test.py DELETED
@@ -1,43 +0,0 @@
import numpy as np
import gymnasium as gym
from tqdm import tqdm

policy_file = "policy.npy"
n_steps = 500
n_test_episodes = 10


def main():
    print("=" * 80)
    print("# Cliff Walking - Monte Carlo Test")
    print("=" * 80)
    # save the policy
    print(f"Loading policy from file: '{policy_file}'...")
    Pi = np.load(policy_file)
    print("Policy:")
    print(Pi)
    print(f"shape: {Pi.shape}")
    _, n_actions = Pi.shape

    print("=" * 80)
    print(f"Testing policy for {n_test_episodes} episodes...")
    env = gym.make("CliffWalking-v0", render_mode="human")
    for e in range(n_test_episodes):
        print(f"Test #{e + 1}:", end=" ")

        state, _ = env.reset()
        for _ in range(n_steps):
            action = np.random.choice(n_actions, p=Pi[state])
            next_state, reward, done, _, _ = env.step(action)
            state = next_state
            if done:
                print("Success!")
                break
        else:
            print("Failed!")

    env.close()


if __name__ == "__main__":
    main()
mc/mc_train.py DELETED
@@ -1,109 +0,0 @@
import numpy as np
import gymnasium as gym
from tqdm import tqdm


def main():
    print("# Cliff Walking - Monte Carlo Train")
    env = gym.make("CliffWalking-v0")

    # Training parameters
    gamma, epsilon = 0.99, 0.1
    n_train_episodes, n_test_episodes, n_max_steps = 2000, 10, 500
    n_states, n_actions = env.observation_space.n, env.action_space.n
    print("=" * 80)
    print(f"gamma: {gamma}")
    print(f"epsilon: {epsilon}")
    print(f"n_episodes: {n_train_episodes}")
    print(f"n_steps: {n_max_steps}")
    print(f"n_states: {n_states}")
    print(f"n_actions: {n_actions}")
    print("=" * 80)

    # An arbitrary e-greedy policy
    Pi = np.full((n_states, n_actions), epsilon / n_actions)
    Pi[np.arange(n_states), np.random.randint(n_actions, size=n_states)] = (
        1 - epsilon + epsilon / n_actions
    )
    print("=" * 80)
    print("Initial policy:")
    print(Pi)
    print("=" * 80)
    Q = np.zeros((n_states, n_actions))
    R = [[[] for _ in range(n_actions)] for _ in range(n_states)]

    successes = []
    tqrange = tqdm(range(n_train_episodes))
    for i in tqrange:
        tqrange.set_description(f"Episode {i + 1:>4}")
        state, _ = env.reset()
        # Generate an episode following the current policy
        episode = []
        for _ in range(n_max_steps):
            # Randomly choose an action from the e-greedy policy
            action = np.random.choice(n_actions, p=Pi[state])
            # Take the action and observe the reward and next state
            next_state, reward, done, _, _ = env.step(action)
            episode.append((state, action, reward))
            state = next_state
            # This is where the agent got to the goal.
            # In the case in which the agent jumped off the cliff, it is simply respawned at the start position without termination.
            if done:
                successes.append(1)
                break
        else:
            successes.append(0)

        G = 0
        # For each step of the episode, in reverse order
        for t in range(len(episode) - 1, -1, -1):
            state, action, reward = episode[t]
            # Update the expected return
            G = gamma * G + reward
            # If we haven't already visited this state-action pair up to this point, then we can update the Q-table and policy
            # This is the first-visit MC method
            if (state, action) not in [(x[0], x[1]) for x in episode[:t]]:
                R[state][action].append(G)
                Q[state, action] = np.mean(R[state][action])
                # e-greedy policy update
                Pi[state] = np.full(n_actions, epsilon / n_actions)
                # the greedy action is the one with the highest Q-value
                Pi[state, np.argmax(Q[state])] = 1 - epsilon + epsilon / n_actions

        success_rate_100 = np.mean(successes[-100:])
        success_rate_250 = np.mean(successes[-250:])
        success_rate_500 = np.mean(successes[-500:])
        tqrange.set_postfix(
            success_rate_100=f"{success_rate_100:.3f}",
            success_rate_250=f"{success_rate_250:.3f}",
            success_rate_500=f"{success_rate_500:.3f}",
        )

    print("Final policy:")
    print(Pi)
    np.save("policy.npy", Pi)

    print("=" * 80)
    print(f"Testing policy for {n_test_episodes} episodes...")
    # Test the policy for a few episodes
    env = gym.make("CliffWalking-v0", render_mode="human")
    for e in range(n_test_episodes):
        print(f"Test #{e + 1}:", end=" ")

        state, _ = env.reset()
        for _ in range(n_max_steps):
            action = np.random.choice(n_actions, p=Pi[state])
            next_state, reward, done, _, _ = env.step(action)
            state = next_state
            if done:
                print("Success!")
                break
        else:
            print("Failed!")

    # Close the environment
    env.close()


if __name__ == "__main__":
    main()
mc/policy.npy DELETED
Binary file (1.66 kB)