Andrei Cozma commited on
Commit
0f41753
·
1 Parent(s): adada5a
Files changed (2) hide show
  1. demo.py +1 -1
  2. test_params.py +59 -62
demo.py CHANGED
@@ -79,7 +79,7 @@ def reset_change(state, policy_fname):
79
 
80
 
81
  def reset_click(state):
82
- state.should_reset = True
83
  state.live_paused = default_paused
84
  state.live_render_fps = default_render_fps
85
  state.live_epsilon = default_epsilon
 
79
 
80
 
81
  def reset_click(state):
82
+ state.should_reset = state.current_policy is not None
83
  state.live_paused = default_paused
84
  state.live_render_fps = default_render_fps
85
  state.live_epsilon = default_epsilon
test_params.py CHANGED
@@ -4,73 +4,70 @@ import multiprocessing
4
  import random
5
 
6
 
7
- def run(args):
8
- env, num_tests, wandb_project = args.env, args.num_tests, args.wandb_project
9
- agent = "MCAgent"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- vals_update_type = [
12
- "first_visit"
13
- ] # Note: Every visit takes too long due to these environment's reward structure
14
- vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
15
- vals_gamma = [1.0, 0.98, 0.96, 0.94]
16
-
17
- if env == "Taxi-v3":
18
- n_train_episodes = 10000
19
- max_steps = 500
20
- elif env == "FrozenLake-v1":
21
- n_train_episodes = 5000
22
- max_steps = 200
23
- elif env == "CliffWalking-v0":
24
- n_train_episodes = 2500
25
- max_steps = 200
26
- else:
27
- raise ValueError(f"Unsupported environment: {env}")
28
 
 
 
29
 
30
- def run_test(args):
31
- command = f"python3 run.py --train --agent {agent} --env {env}"
32
- command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
33
- command += f" --gamma {args[0]} --epsilon {args[1]} --update_type {args[2]}"
34
- command += f" --run_name_suffix {args[3]}"
35
- if wandb_project is not None:
36
- command += f" --wandb_project {wandb_project}"
37
- command += " --no_save"
38
- os.system(command)
39
 
40
- with multiprocessing.Pool(8) as p:
41
- tests = []
42
- for update_type in vals_update_type:
43
- for gamma in vals_gamma:
44
- for eps in vals_epsilon:
45
- tests.extend((gamma, eps, update_type, i) for i in range(num_tests))
46
- random.shuffle(tests)
 
 
 
 
47
 
48
- p.map(run_test, tests)
49
 
 
 
 
 
 
 
 
 
 
50
 
51
- def main():
52
- # argument parsing
53
- parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
54
- parser.add_argument(
55
- "--env",
56
- type=str,
57
- default="Taxi-v3",
58
- help="environment to run",
59
- )
60
- parser.add_argument(
61
- "--num_tests",
62
- type=int,
63
- default=10,
64
- help="number of tests to run for each parameter combination",
65
- )
66
- parser.add_argument(
67
- "--wandb_project",
68
- type=str,
69
- default=None,
70
- help="wandb project name to log to",
71
- )
72
-
73
- args = parser.parse_args()
74
-
75
- run(args)
76
 
 
 
 
 
 
 
 
 
 
 
4
  import random
5
 
6
 
7
+ # argument parsing
8
+ parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
9
+ parser.add_argument(
10
+ "--env",
11
+ type=str,
12
+ default="Taxi-v3",
13
+ help="environment to run",
14
+ )
15
+ parser.add_argument(
16
+ "--num_tests",
17
+ type=int,
18
+ default=25,
19
+ help="number of tests to run for each parameter combination",
20
+ )
21
+ parser.add_argument(
22
+ "--wandb_project",
23
+ type=str,
24
+ default=None,
25
+ help="wandb project name to log to",
26
+ )
27
 
28
+ args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ env, num_tests, wandb_project = args.env, args.num_tests, args.wandb_project
31
+ agent = "MCAgent"
32
 
33
+ vals_update_type = [
34
+ "first_visit"
35
+ ] # Note: Every visit takes too long due to these environment's reward structure
36
+ vals_gamma = [1.0, 0.98, 0.96, 0.94]
37
+ vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
38
+ # vals_gamma = [1.0]
39
+ # vals_epsilon = [0.5]
 
 
40
 
41
+ if env == "CliffWalking-v0":
42
+ n_train_episodes = 2500
43
+ max_steps = 200
44
+ elif env == "FrozenLake-v1":
45
+ n_train_episodes = 5000
46
+ max_steps = 200
47
+ elif env == "Taxi-v3":
48
+ n_train_episodes = 10000
49
+ max_steps = 500
50
+ else:
51
+ raise ValueError(f"Unsupported environment: {env}")
52
 
 
53
 
54
+ def run_test(args):
55
+ command = f"python3 run.py --train --agent {agent} --env {env}"
56
+ command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
57
+ command += f" --gamma {args[0]} --epsilon {args[1]} --update_type {args[2]}"
58
+ command += f" --run_name_suffix {args[3]}"
59
+ if wandb_project is not None:
60
+ command += f" --wandb_project {wandb_project}"
61
+ command += " --no_save"
62
+ os.system(command)
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
+ with multiprocessing.Pool(8) as p:
66
+ tests = []
67
+ for update_type in vals_update_type:
68
+ for gamma in vals_gamma:
69
+ for eps in vals_epsilon:
70
+ tests.extend((gamma, eps, update_type, i) for i in range(num_tests))
71
+ random.shuffle(tests)
72
+
73
+ p.map(run_test, tests)