Andrei Cozma committed on
Commit 3e2038a · 1 Parent(s): b8a5bf6
Files changed (5)
  1. DPAgent.py +2 -2
  2. MonteCarloAgent.py +1 -177
  3. agents.py +8 -0
  4. demo.py +5 -32
  5. run.py +187 -0
DPAgent.py CHANGED
@@ -5,7 +5,7 @@ from matplotlib import pyplot as plt
 from tqdm import trange
 
 
-class DP:
+class DPAgent:
     def __init__(self, env_name, gamma=0.9, theta=1e-10, **kwargs):
         self.env = gym.make(env_name, **kwargs)
         self.gamma = gamma
@@ -85,7 +85,7 @@ class DP:
 
 if __name__ == "__main__":
     # env = gym.make('FrozenLake-v1', render_mode='human')
-    dp = DP("FrozenLake-v1", is_slippery=False, desc=[
+    dp = DPAgent("FrozenLake-v1", is_slippery=False, desc=[
        "SFFFFFFF",
        "FFFFFFFH",
        "FFFHFFFF",
MonteCarloAgent.py CHANGED
@@ -2,12 +2,10 @@ import os
 import numpy as np
 import gymnasium as gym
 from tqdm import tqdm
-import argparse
-from gymnasium.envs.toy_text.frozen_lake import generate_random_map
 import wandb
 from .Shared import Shared
 
-class MonteCarloAgent(Shared):
+class MCAgent(Shared):
     def __init__(
         self,
         **kwargs,
@@ -166,177 +164,3 @@ class MonteCarloAgent(Shared):
             }
         )
 
-
-def main():
-    parser = argparse.ArgumentParser()
-
-    ### Train/Test parameters
-    parser.add_argument(
-        "--train",
-        action="store_true",
-        help="Use this flag to train the agent.",
-    )
-    parser.add_argument(
-        "--test",
-        type=str,
-        default=None,
-        help="Use this flag to test the agent. Provide the path to the policy file.",
-    )
-    parser.add_argument(
-        "--n_train_episodes",
-        type=int,
-        default=2500,
-        help="The number of episodes to train for. (default: 2500)",
-    )
-    parser.add_argument(
-        "--n_test_episodes",
-        type=int,
-        default=100,
-        help="The number of episodes to test for. (default: 100)",
-    )
-    parser.add_argument(
-        "--test_every",
-        type=int,
-        default=100,
-        help="During training, test the agent every n episodes. (default: 100)",
-    )
-
-    parser.add_argument(
-        "--max_steps",
-        type=int,
-        default=200,
-        help="The maximum number of steps per episode before the episode is forced to end. (default: 200)",
-    )
-
-    parser.add_argument(
-        "--update_type",
-        type=str,
-        choices=["first_visit", "every_visit"],
-        default="first_visit",
-        help="The type of update to use. (default: first_visit)",
-    )
-
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        default="policies",
-        help="The directory to save the policy to. (default: policies)",
-    )
-
-    parser.add_argument(
-        "--no_save",
-        action="store_true",
-        help="Use this flag to disable saving the policy.",
-    )
-
-    ### Agent parameters
-    parser.add_argument(
-        "--gamma",
-        type=float,
-        default=1.0,
-        help="The value for the discount factor to use. (default: 1.0)",
-    )
-    parser.add_argument(
-        "--epsilon",
-        type=float,
-        default=0.4,
-        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
-    )
-
-    ### Environment parameters
-    parser.add_argument(
-        "--env",
-        type=str,
-        default="CliffWalking-v0",
-        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
-        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
-    )
-    parser.add_argument(
-        "--render_mode",
-        type=str,
-        default=None,
-        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
-    )
-    parser.add_argument(
-        "--wandb_project",
-        type=str,
-        default=None,
-        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
-    )
-    parser.add_argument(
-        "--wandb_group",
-        type=str,
-        default="monte-carlo",
-        help="WandB group name for logging. (default: monte-carlo)",
-    )
-    parser.add_argument(
-        "--wandb_job_type",
-        type=str,
-        default="train",
-        help="WandB job type for logging. (default: train)",
-    )
-    parser.add_argument(
-        "--wandb_run_name_suffix",
-        type=str,
-        default=None,
-        help="WandB run name suffix for logging. (default: None)",
-    )
-
-    args = parser.parse_args()
-
-    agent = MonteCarloAgent(
-        args.env,
-        gamma=args.gamma,
-        epsilon=args.epsilon,
-        render_mode=args.render_mode,
-    )
-
-    run_name = f"{agent.__class__.__name__}_{args.env}_e{args.n_train_episodes}_s{args.max_steps}_g{args.gamma}_e{args.epsilon}_{args.update_type}"
-    if args.wandb_run_name_suffix is not None:
-        run_name += f"+{args.wandb_run_name_suffix}"
-
-    agent.run_name = run_name
-
-    try:
-        if args.train:
-            # Log to WandB
-            if args.wandb_project is not None:
-                wandb.init(
-                    project=args.wandb_project,
-                    name=run_name,
-                    group=args.wandb_group,
-                    job_type=args.wandb_job_type,
-                    config=dict(args._get_kwargs()),
-                )
-
-            agent.train(
-                n_train_episodes=args.n_train_episodes,
-                test_every=args.test_every,
-                n_test_episodes=args.n_test_episodes,
-                max_steps=args.max_steps,
-                update_type=args.update_type,
-                log_wandb=args.wandb_project is not None,
-                save_best=True,
-                save_best_dir=args.save_dir,
-            )
-            if not args.no_save:
-                agent.save_policy(
-                    fname=f"{run_name}.npy",
-                    save_dir=args.save_dir,
-                )
-        elif args.test is not None:
-            if not args.test.endswith(".npy"):
-                args.test += ".npy"
-            agent.load_policy(args.test)
-            agent.test(
-                n_test_episodes=args.n_test_episodes,
-                max_steps=args.max_steps,
-            )
-        else:
-            print("ERROR: Please provide either --train or --test.")
-    except KeyboardInterrupt:
-        print("Exiting...")
-
-
-if __name__ == "__main__":
-    main()
 
agents.py ADDED
@@ -0,0 +1,8 @@
+# All supported agents
+from MonteCarloAgent import MCAgent
+from DPAgent import DPAgent
+
+AGENTS_MAP = {
+    "MCAgent": MCAgent,
+    "DPAgent": DPAgent
+}
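Note: AGENTS_MAP is the name-to-class registry that demo.py and run.py (below) use to look up an agent. A minimal usage sketch, with illustrative values only; the constructor keyword arguments mirror the run.py call later in this commit:

    from agents import AGENTS_MAP

    # Look up the class by its registry key and instantiate it the way run.py does.
    # "MCAgent", "CliffWalking-v0", gamma, and epsilon are example values, not fixed.
    AgentClass = AGENTS_MAP["MCAgent"]          # or AGENTS_MAP["DPAgent"]
    agent = AgentClass("CliffWalking-v0", gamma=1.0, epsilon=0.4, render_mode=None)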
demo.py CHANGED
@@ -2,11 +2,12 @@ import os
 import time
 import numpy as np
 import gradio as gr
-from MonteCarloAgent import MonteCarloAgent
-from DPAgent import DP
+
 import scipy.ndimage
 import cv2
 
+from agents import AGENTS_MAP
+
 default_n_test_episodes = 10
 default_max_steps = 500
 default_render_fps = 5
@@ -26,11 +27,7 @@ except FileNotFoundError:
     print("ERROR: No policies folder found!")
     all_policies = []
 
-# All supported agents
-agent_map = {
-    "MonteCarloAgent": MonteCarloAgent,
-    "DPAgent": DP
-}
+
 action_map = {
     "CliffWalking-v0": {
         0: "up",
@@ -127,7 +124,7 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
 
     agent_type, env_name = props[0], props[1]
 
-    agent = agent_map[agent_type](env_name, render_mode="rgb_array")
+    agent = AGENTS_MAP[agent_type](env_name, render_mode="rgb_array")
     agent.load_policy(policy_path)
     env_action_map = action_map.get(env_name)
 
@@ -166,30 +163,6 @@ def run(policy_fname, n_test_episodes, max_steps, render_fps, epsilon):
     curr_policy -= np.min(curr_policy)
     curr_policy = curr_policy / np.sum(curr_policy)
 
-    # frame_env = cv2.resize(
-    #     frame_env,
-    #     (
-    #         int(frame_env.shape[1] / frame_env.shape[0] * frame_env_h),
-    #         frame_env_h,
-    #     ),
-    #     interpolation=cv2.INTER_AREA,
-    # )
-
-    # if frame_env.shape[1] < frame_env_w:
-    #     rgb_array_new = np.pad(
-    #         frame_env,
-    #         (
-    #             (0, 0),
-    #             (
-    #                 (frame_env_w - frame_env.shape[1]) // 2,
-    #                 (frame_env_w - frame_env.shape[1]) // 2,
-    #             ),
-    #             (0, 0),
-    #         ),
-    #         "constant",
-    #     )
-    #     frame_env = np.uint8(rgb_array_new)
-
     frame_policy_h = frame_policy_res // len(curr_policy)
     frame_policy = np.zeros((frame_policy_h, frame_policy_res))
    for i, p in enumerate(curr_policy):
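Note: the run() helper above recovers the agent class and environment from the selected policy filename (props[0] and props[1]); the filename layout itself comes from run_name in run.py below. A small sketch of that convention, assuming the split (not shown in this diff) is on underscores:

    # Hypothetical filename built from run.py's run_name pattern with default settings.
    policy_fname = "MCAgent_CliffWalking-v0_e2500_s200_g1.0_e0.4_first_visit.npy"
    props = policy_fname.split("_")             # assumption: underscore-delimited
    agent_type, env_name = props[0], props[1]   # -> "MCAgent", "CliffWalking-v0"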
run.py ADDED
@@ -0,0 +1,187 @@
+import argparse
+import wandb
+
+from agents import AGENTS_MAP
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    ### Train/Test parameters
+    parser.add_argument(
+        "--train",
+        action="store_true",
+        help="Use this flag to train the agent.",
+    )
+    parser.add_argument(
+        "--test",
+        type=str,
+        default=None,
+        help="Use this flag to test the agent. Provide the path to the policy file.",
+    )
+    parser.add_argument(
+        "--n_train_episodes",
+        type=int,
+        default=2500,
+        help="The number of episodes to train for. (default: 2500)",
+    )
+    parser.add_argument(
+        "--n_test_episodes",
+        type=int,
+        default=100,
+        help="The number of episodes to test for. (default: 100)",
+    )
+    parser.add_argument(
+        "--test_every",
+        type=int,
+        default=100,
+        help="During training, test the agent every n episodes. (default: 100)",
+    )
+
+    parser.add_argument(
+        "--max_steps",
+        type=int,
+        default=200,
+        help="The maximum number of steps per episode before the episode is forced to end. (default: 200)",
+    )
+
+    parser.add_argument(
+        "--update_type",
+        type=str,
+        choices=["first_visit", "every_visit"],
+        default="first_visit",
+        help="The type of update to use. (default: first_visit)",
+    )
+
+    parser.add_argument(
+        "--save_dir",
+        type=str,
+        default="policies",
+        help="The directory to save the policy to. (default: policies)",
+    )
+
+    parser.add_argument(
+        "--no_save",
+        action="store_true",
+        help="Use this flag to disable saving the policy.",
+    )
+
+    ### Agent parameters
+    parser.add_argument(
+        "--agent",
+        type=str,
+        required=True,
+        choices=AGENTS_MAP.keys(),
+        help=f"The agent to use. One of: {AGENTS_MAP.keys()}",
+    )
+
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=1.0,
+        help="The value for the discount factor to use. (default: 1.0)",
+    )
+    parser.add_argument(
+        "--epsilon",
+        type=float,
+        default=0.4,
+        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
+    )
+
+    ### Environment parameters
+    parser.add_argument(
+        "--env",
+        type=str,
+        default="CliffWalking-v0",
+        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
+        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
+    )
+
+    parser.add_argument(
+        "--render_mode",
+        type=str,
+        default=None,
+        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
+    )
+    parser.add_argument(
+        "--wandb_project",
+        type=str,
+        default=None,
+        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
+    )
+    parser.add_argument(
+        "--wandb_group",
+        type=str,
+        default="monte-carlo",
+        help="WandB group name for logging. (default: monte-carlo)",
+    )
+    parser.add_argument(
+        "--wandb_job_type",
+        type=str,
+        default="train",
+        help="WandB job type for logging. (default: train)",
+    )
+    parser.add_argument(
+        "--wandb_run_name_suffix",
+        type=str,
+        default=None,
+        help="WandB run name suffix for logging. (default: None)",
+    )
+
+    args = parser.parse_args()
+
+    agent = AGENTS_MAP[args.agent](
+        args.env,
+        gamma=args.gamma,
+        epsilon=args.epsilon,
+        render_mode=args.render_mode,
+    )
+
+    run_name = f"{agent.__class__.__name__}_{args.env}_e{args.n_train_episodes}_s{args.max_steps}_g{args.gamma}_e{args.epsilon}_{args.update_type}"
+    if args.wandb_run_name_suffix is not None:
+        run_name += f"+{args.wandb_run_name_suffix}"
+
+    agent.run_name = run_name
+
+    try:
+        if args.train:
+            # Log to WandB
+            if args.wandb_project is not None:
+                wandb.init(
+                    project=args.wandb_project,
+                    name=run_name,
+                    group=args.wandb_group,
+                    job_type=args.wandb_job_type,
+                    config=dict(args._get_kwargs()),
+                )
+
+            agent.train(
+                n_train_episodes=args.n_train_episodes,
+                test_every=args.test_every,
+                n_test_episodes=args.n_test_episodes,
+                max_steps=args.max_steps,
+                update_type=args.update_type,
+                log_wandb=args.wandb_project is not None,
+                save_best=True,
+                save_best_dir=args.save_dir,
+            )
+            if not args.no_save:
+                agent.save_policy(
+                    fname=f"{run_name}.npy",
+                    save_dir=args.save_dir,
+                )
+        elif args.test is not None:
+            if not args.test.endswith(".npy"):
+                args.test += ".npy"
+            agent.load_policy(args.test)
+            agent.test(
+                n_test_episodes=args.n_test_episodes,
+                max_steps=args.max_steps,
+            )
+        else:
+            print("ERROR: Please provide either --train or --test.")
+    except KeyboardInterrupt:
+        print("Exiting...")
+
+
+if __name__ == "__main__":
+    main()
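Note: given the flags defined above, a typical training invocation would presumably be "python run.py --agent MCAgent --env CliffWalking-v0 --train", and "python run.py --agent MCAgent --test <policy>.npy" for evaluating a saved policy; --agent is the only required argument and must be one of the AGENTS_MAP keys.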