Andrei Cozma committed · 7d3766a
1 Parent(s): 4a6d8ec

Updates

- AgentBase.py +7 -2
- run.py +2 -2
- test_params.py +33 -8
AgentBase.py CHANGED

@@ -91,7 +91,12 @@ class AgentBase:
             p=[1.0 - self.epsilon_override, self.epsilon_override],
         )

-    def generate_episode(self, policy, max_steps=..., render=False, **kwargs):
+    def generate_episode(self, policy, max_steps=None, render=False, **kwargs):
+        if max_steps is None:
+            # If max_steps is not specified, we use a rough estimate of
+            # the maximum number of steps it should take to solve the environment
+            max_steps = self.n_states * self.n_actions
+
         state, _ = self.env.reset()
         episode_hist, solved, done = [], False, False
         rgb_array = self.env.render() if render else None
@@ -139,7 +144,7 @@ class AgentBase:
             rgb_array = self.env.render() if render else None
             yield episode_hist, solved, rgb_array

-    def run_episode(self, policy, max_steps=..., render=False, **kwargs):
+    def run_episode(self, policy, max_steps=None, render=False, **kwargs):
         # Run the generator until the end
         episode_hist, solved, rgb_array = list(
             self.generate_episode(policy, max_steps, render, **kwargs)
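This change replaces a hard-coded episode cap with a None sentinel that falls back to n_states * n_actions. A minimal, self-contained sketch of the fallback pattern, assuming a tabular Gymnasium environment (the standalone episode_cap helper and the FrozenLake-v1 usage are illustrative, not part of this commit):

import gymnasium as gym

def episode_cap(env, max_steps=None):
    # None sentinel: fall back to a rough upper bound on the number of
    # steps needed to solve a tabular environment, as in generate_episode().
    if max_steps is None:
        max_steps = env.observation_space.n * env.action_space.n
    return max_steps

env = gym.make("FrozenLake-v1", map_name="8x8")
print(episode_cap(env))       # 64 states * 4 actions -> 256
print(episode_cap(env, 200))  # an explicit cap is used as-is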
run.py CHANGED

@@ -39,8 +39,8 @@ def main():
     parser.add_argument(
         "--max_steps",
         type=int,
-        default=...,
-        help="The maximum number of steps per episode before the episode is forced to end. (default: ...)",
+        default=None,
+        help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to n_states * n_actions of the environment. (default: None)",
     )

     ### Agent parameters
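The CLI mirrors the same sentinel: --max_steps defaults to None, and the actual cap is computed downstream from the environment. A short sketch of the argparse pattern in isolation (this standalone parser is illustrative, not run.py itself):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--max_steps",
    type=int,
    default=None,  # sentinel: resolved from the environment later
    help="Maximum steps per episode; computed from the environment if omitted.",
)

print(parser.parse_args(["--max_steps", "300"]).max_steps)  # 300
print(parser.parse_args([]).max_steps)  # None -> fallback computed downstream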
test_params.py CHANGED

@@ -31,18 +31,21 @@ env, num_tests, wandb_project = args.env, args.num_tests, args.wandb_project
 agent = "MCAgent"

 vals_update_type = [
-    "...",
+    # "on_policy",
+    "off_policy",
 ]  # Note: Every visit takes too long due to these environments' reward structure
-vals_gamma = [1.0, 0.98, 0.96, 0.94]
+# vals_gamma = [1.0, 0.98, 0.96, 0.94]
 vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]
-
+vals_gamma = [1.0]
 # vals_epsilon = [0.5]

+vals_size = [8, 16, 32, 64]
+
 if env == "CliffWalking-v0":
     n_train_episodes = 2500
     max_steps = 200
 elif env == "FrozenLake-v1":
-    n_train_episodes = ...
+    n_train_episodes = 25000
     max_steps = 200
 elif env == "Taxi-v3":
     n_train_episodes = 10000
@@ -53,9 +56,10 @@ else:

 def run_test(args):
     command = f"python3 run.py --train --agent {agent} --env {env}"
-    command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
-    command += f" --..."
-
+    # command += f" --n_train_episodes {n_train_episodes} --max_steps {max_steps}"
+    command += f" --n_train_episodes {n_train_episodes}"
+    for k, v in args.items():
+        command += f" --{k} {v}"
     if wandb_project is not None:
         command += f" --wandb_project {wandb_project}"
     command += " --no_save"
@@ -67,7 +71,28 @@ with multiprocessing.Pool(8) as p:
     for update_type in vals_update_type:
         for gamma in vals_gamma:
             for eps in vals_epsilon:
-                ...
+                if env == "FrozenLake-v1":
+                    for size in vals_size:
+                        tests.extend(
+                            {
+                                "gamma": gamma,
+                                "epsilon": eps,
+                                "update_type": update_type,
+                                "size": size,
+                                "run_name_suffix": i,
+                            }
+                            for i in range(num_tests)
+                        )
+                else:
+                    tests.extend(
+                        {
+                            "gamma": gamma,
+                            "epsilon": eps,
+                            "update_type": update_type,
+                            "run_name_suffix": i,
+                        }
+                        for i in range(num_tests)
+                    )
     random.shuffle(tests)

     p.map(run_test, tests)
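With this change, run_test() forwards every key in each args dict as a "--key value" flag, and the sweep gains a per-environment size axis for FrozenLake-v1. A minimal sketch of how the grid expands into commands (itertools.product, build_command, and the trimmed values are illustrative stand-ins, not the script itself):

import itertools

# Trimmed parameter grid; the real script sweeps more values.
vals_update_type = ["off_policy"]
vals_gamma = [1.0]
vals_epsilon = [0.1, 0.5]
vals_size = [8, 16]
num_tests = 2

tests = [
    {"gamma": g, "epsilon": e, "update_type": u, "size": s, "run_name_suffix": i}
    for u, g, e, s in itertools.product(
        vals_update_type, vals_gamma, vals_epsilon, vals_size
    )
    for i in range(num_tests)
]

def build_command(args):
    # Each key/value pair becomes a "--key value" flag, as in run_test().
    command = "python3 run.py --train --agent MCAgent --env FrozenLake-v1"
    for k, v in args.items():
        command += f" --{k} {v}"
    return command

print(len(tests))              # 1 * 1 * 2 * 2 * 2 = 8 configurations
print(build_command(tests[0]))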