Spaces:
Sleeping
Sleeping
Andrei Cozma
commited on
Commit
·
0f41753
1
Parent(s):
adada5a
Updates
Browse files- demo.py +1 -1
- test_params.py +59 -62
demo.py
CHANGED
@@ -79,7 +79,7 @@ def reset_change(state, policy_fname):
|
|
79 |
|
80 |
|
81 |
def reset_click(state):
|
82 |
-
state.should_reset =
|
83 |
state.live_paused = default_paused
|
84 |
state.live_render_fps = default_render_fps
|
85 |
state.live_epsilon = default_epsilon
|
|
|
79 |
|
80 |
|
81 |
def reset_click(state):
|
82 |
+
state.should_reset = state.current_policy is not None
|
83 |
state.live_paused = default_paused
|
84 |
state.live_render_fps = default_render_fps
|
85 |
state.live_epsilon = default_epsilon
|
test_params.py
CHANGED
@@ -4,73 +4,70 @@ import multiprocessing
|
|
4 |
import random
|
5 |
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
"first_visit"
|
13 |
-
] # Note: Every visit takes too long due to these environment's reward structure
|
14 |
-
vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
|
15 |
-
vals_gamma = [1.0, 0.98, 0.96, 0.94]
|
16 |
-
|
17 |
-
if env == "Taxi-v3":
|
18 |
-
n_train_episodes = 10000
|
19 |
-
max_steps = 500
|
20 |
-
elif env == "FrozenLake-v1":
|
21 |
-
n_train_episodes = 5000
|
22 |
-
max_steps = 200
|
23 |
-
elif env == "CliffWalking-v0":
|
24 |
-
n_train_episodes = 2500
|
25 |
-
max_steps = 200
|
26 |
-
else:
|
27 |
-
raise ValueError(f"Unsupported environment: {env}")
|
28 |
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
command += " --no_save"
|
38 |
-
os.system(command)
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
p.map(run_test, tests)
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
def main():
|
52 |
-
# argument parsing
|
53 |
-
parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
|
54 |
-
parser.add_argument(
|
55 |
-
"--env",
|
56 |
-
type=str,
|
57 |
-
default="Taxi-v3",
|
58 |
-
help="environment to run",
|
59 |
-
)
|
60 |
-
parser.add_argument(
|
61 |
-
"--num_tests",
|
62 |
-
type=int,
|
63 |
-
default=10,
|
64 |
-
help="number of tests to run for each parameter combination",
|
65 |
-
)
|
66 |
-
parser.add_argument(
|
67 |
-
"--wandb_project",
|
68 |
-
type=str,
|
69 |
-
default=None,
|
70 |
-
help="wandb project name to log to",
|
71 |
-
)
|
72 |
-
|
73 |
-
args = parser.parse_args()
|
74 |
-
|
75 |
-
run(args)
|
76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import random
|
5 |
|
6 |
|
7 |
+
# argument parsing
|
8 |
+
parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
|
9 |
+
parser.add_argument(
|
10 |
+
"--env",
|
11 |
+
type=str,
|
12 |
+
default="Taxi-v3",
|
13 |
+
help="environment to run",
|
14 |
+
)
|
15 |
+
parser.add_argument(
|
16 |
+
"--num_tests",
|
17 |
+
type=int,
|
18 |
+
default=25,
|
19 |
+
help="number of tests to run for each parameter combination",
|
20 |
+
)
|
21 |
+
parser.add_argument(
|
22 |
+
"--wandb_project",
|
23 |
+
type=str,
|
24 |
+
default=None,
|
25 |
+
help="wandb project name to log to",
|
26 |
+
)
|
27 |
|
28 |
+
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
env, num_tests, wandb_project = args.env, args.num_tests, args.wandb_project
|
31 |
+
agent = "MCAgent"
|
32 |
|
33 |
+
vals_update_type = [
|
34 |
+
"first_visit"
|
35 |
+
] # Note: Every visit takes too long due to these environment's reward structure
|
36 |
+
vals_gamma = [1.0, 0.98, 0.96, 0.94]
|
37 |
+
vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
|
38 |
+
# vals_gamma = [1.0]
|
39 |
+
# vals_epsilon = [0.5]
|
|
|
|
|
40 |
|
41 |
+
if env == "CliffWalking-v0":
|
42 |
+
n_train_episodes = 2500
|
43 |
+
max_steps = 200
|
44 |
+
elif env == "FrozenLake-v1":
|
45 |
+
n_train_episodes = 5000
|
46 |
+
max_steps = 200
|
47 |
+
elif env == "Taxi-v3":
|
48 |
+
n_train_episodes = 10000
|
49 |
+
max_steps = 500
|
50 |
+
else:
|
51 |
+
raise ValueError(f"Unsupported environment: {env}")
|
52 |
|
|
|
53 |
|
54 |
+
def run_test(args):
|
55 |
+
command = f"python3 run.py --train --agent {agent} --env {env}"
|
56 |
+
command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
|
57 |
+
command += f" --gamma {args[0]} --epsilon {args[1]} --update_type {args[2]}"
|
58 |
+
command += f" --run_name_suffix {args[3]}"
|
59 |
+
if wandb_project is not None:
|
60 |
+
command += f" --wandb_project {wandb_project}"
|
61 |
+
command += " --no_save"
|
62 |
+
os.system(command)
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
with multiprocessing.Pool(8) as p:
|
66 |
+
tests = []
|
67 |
+
for update_type in vals_update_type:
|
68 |
+
for gamma in vals_gamma:
|
69 |
+
for eps in vals_epsilon:
|
70 |
+
tests.extend((gamma, eps, update_type, i) for i in range(num_tests))
|
71 |
+
random.shuffle(tests)
|
72 |
+
|
73 |
+
p.map(run_test, tests)
|