Andrei Cozma committed on
Commit
7d3766a
·
1 Parent(s): 4a6d8ec
Files changed (3) hide show
  1. AgentBase.py +7 -2
  2. run.py +2 -2
  3. test_params.py +33 -8
AgentBase.py CHANGED
@@ -91,7 +91,12 @@ class AgentBase:
91
  p=[1.0 - self.epsilon_override, self.epsilon_override],
92
  )
93
 
94
- def generate_episode(self, policy, max_steps=500, render=False, **kwargs):
 
 
 
 
 
95
  state, _ = self.env.reset()
96
  episode_hist, solved, done = [], False, False
97
  rgb_array = self.env.render() if render else None
@@ -139,7 +144,7 @@ class AgentBase:
139
  rgb_array = self.env.render() if render else None
140
  yield episode_hist, solved, rgb_array
141
 
142
- def run_episode(self, policy, max_steps=500, render=False, **kwargs):
143
  # Run the generator until the end
144
  episode_hist, solved, rgb_array = list(
145
  self.generate_episode(policy, max_steps, render, **kwargs)
 
91
  p=[1.0 - self.epsilon_override, self.epsilon_override],
92
  )
93
 
94
+ def generate_episode(self, policy, max_steps=None, render=False, **kwargs):
95
+ if max_steps is None:
96
+ # If max_steps is not specified, we use a rough estimate of
97
+ # the maximum number of steps it should take to solve the environment
98
+ max_steps = self.n_states * self.n_actions
99
+
100
  state, _ = self.env.reset()
101
  episode_hist, solved, done = [], False, False
102
  rgb_array = self.env.render() if render else None
 
144
  rgb_array = self.env.render() if render else None
145
  yield episode_hist, solved, rgb_array
146
 
147
+ def run_episode(self, policy, max_steps=None, render=False, **kwargs):
148
  # Run the generator until the end
149
  episode_hist, solved, rgb_array = list(
150
  self.generate_episode(policy, max_steps, render, **kwargs)
run.py CHANGED
@@ -39,8 +39,8 @@ def main():
39
  parser.add_argument(
40
  "--max_steps",
41
  type=int,
42
- default=200,
43
- help="The maximum number of steps per episode before the episode is forced to end. (default: 200)",
44
  )
45
 
46
  ### Agent parameters
 
39
  parser.add_argument(
40
  "--max_steps",
41
  type=int,
42
+ default=None,
43
+ help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to the number of states times the number of actions in the environment. (default: None)",
44
  )
45
 
46
  ### Agent parameters
test_params.py CHANGED
@@ -31,18 +31,21 @@ env, num_tests, wandb_project = args.env, args.num_tests, args.wandb_project
31
  agent = "MCAgent"
32
 
33
  vals_update_type = [
34
- "first_visit"
 
35
  ] # Note: Every visit takes too long due to these environment's reward structure
36
- vals_gamma = [1.0, 0.98, 0.96, 0.94]
37
  vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
38
- # vals_gamma = [1.0]
39
  # vals_epsilon = [0.5]
40
 
 
 
41
  if env == "CliffWalking-v0":
42
  n_train_episodes = 2500
43
  max_steps = 200
44
  elif env == "FrozenLake-v1":
45
- n_train_episodes = 5000
46
  max_steps = 200
47
  elif env == "Taxi-v3":
48
  n_train_episodes = 10000
@@ -53,9 +56,10 @@ else:
53
 
54
  def run_test(args):
55
  command = f"python3 run.py --train --agent {agent} --env {env}"
56
- command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
57
- command += f" --gamma {args[0]} --epsilon {args[1]} --update_type {args[2]}"
58
- command += f" --run_name_suffix {args[3]}"
 
59
  if wandb_project is not None:
60
  command += f" --wandb_project {wandb_project}"
61
  command += " --no_save"
@@ -67,7 +71,28 @@ with multiprocessing.Pool(8) as p:
67
  for update_type in vals_update_type:
68
  for gamma in vals_gamma:
69
  for eps in vals_epsilon:
70
- tests.extend((gamma, eps, update_type, i) for i in range(num_tests))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  random.shuffle(tests)
72
 
73
  p.map(run_test, tests)
 
31
  agent = "MCAgent"
32
 
33
  vals_update_type = [
34
+ # "on_policy",
35
+ "off_policy",
36
  ] # Note: Every visit takes too long due to these environment's reward structure
37
+ # vals_gamma = [1.0, 0.98, 0.96, 0.94]
38
  vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
39
+ vals_gamma = [1.0]
40
  # vals_epsilon = [0.5]
41
 
42
+ vals_size = [8, 16, 32, 64]
43
+
44
  if env == "CliffWalking-v0":
45
  n_train_episodes = 2500
46
  max_steps = 200
47
  elif env == "FrozenLake-v1":
48
+ n_train_episodes = 25000
49
  max_steps = 200
50
  elif env == "Taxi-v3":
51
  n_train_episodes = 10000
 
56
 
57
  def run_test(args):
58
  command = f"python3 run.py --train --agent {agent} --env {env}"
59
+ # command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
60
+ command += f" --n_train_episodes {n_train_episodes}"
61
+ for k, v in args.items():
62
+ command += f" --{k} {v}"
63
  if wandb_project is not None:
64
  command += f" --wandb_project {wandb_project}"
65
  command += " --no_save"
 
71
  for update_type in vals_update_type:
72
  for gamma in vals_gamma:
73
  for eps in vals_epsilon:
74
+ if env == "FrozenLake-v1":
75
+ for size in vals_size:
76
+ tests.extend(
77
+ {
78
+ "gamma": gamma,
79
+ "epsilon": eps,
80
+ "update_type": update_type,
81
+ "size": size,
82
+ "run_name_suffix": i,
83
+ }
84
+ for i in range(num_tests)
85
+ )
86
+ else:
87
+ tests.extend(
88
+ {
89
+ "gamma": gamma,
90
+ "epsilon": eps,
91
+ "update_type": update_type,
92
+ "run_name_suffix": i,
93
+ }
94
+ for i in range(num_tests)
95
+ )
96
  random.shuffle(tests)
97
 
98
  p.map(run_test, tests)