Spaces:
Sleeping
Sleeping
File size: 2,788 Bytes
46b0409 0f41753 120dc90 0f41753 120dc90 0f41753 46b0409 0f41753 46b0409 0f41753 46b0409 0f41753 7d3766a 0f41753 7d3766a 0f41753 7d3766a 0f41753 46b0409 7d3766a 0f41753 120dc90 0f41753 7d3766a 120dc90 0f41753 120dc90 0f41753 46b0409 0f41753 7d3766a 0f41753 46b0409 0f41753 7d3766a 0f41753 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import argparse
import multiprocessing
import os
import random
import subprocess
# Command-line interface: which environment to sweep, how many repetitions
# per parameter combination, and an optional wandb project to log to.
parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
parser.add_argument("--env", default="FrozenLake-v1", type=str, help="environment to run")
parser.add_argument(
    "--num_tests",
    default=10,
    type=int,
    help="number of tests to run for each parameter combination",
)
parser.add_argument("--wandb_project", default=None, type=str, help="wandb project name to log to")

args = parser.parse_args()
env = args.env
num_tests = args.num_tests
wandb_project = args.wandb_project
# Agent name forwarded to run.py for every run in the sweep.
agent = "MCAgent"

# Parameter grids swept below.
# Note: Every visit takes too long due to these environments' reward structure.
vals_update_type = ["off_policy"]  # "on_policy" is currently disabled
vals_gamma = [1.0]  # wider grid tried previously: [1.0, 0.98, 0.96, 0.94]
vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]  # single-value alternative: [0.5]
vals_size = [8, 16, 32, 64]

# Training-episode budget per supported environment.
# (A per-environment max_steps of 200 / 200 / 500 is commented out and not used.)
_train_episodes_by_env = {
    "CliffWalking-v0": 2500,
    "FrozenLake-v1": 25000,
    "Taxi-v3": 10000,
}
if env not in _train_episodes_by_env:
    raise ValueError(f"Unsupported environment: {env}")
n_train_episodes = _train_episodes_by_env[env]
def run_test(args):
    """Launch one training run of ``run.py`` for a single parameter combination.

    The command always trains the module-level ``agent`` on the module-level
    ``env`` for ``n_train_episodes`` episodes, appends ``--wandb_project`` when
    one was given on the command line, and ends with ``--no_save``.
    ``--max_steps`` is not forwarded.

    Args:
        args: Mapping of option name to value. Each item is forwarded to
            ``run.py`` as ``--<name> <value>``, in the mapping's iteration
            order.

    Returns:
        None. The child's exit status is not inspected, so one failing run
        does not abort the remaining runs of the sweep.
    """
    command = [
        "python3",
        "run.py",
        "--train",
        "--agent",
        str(agent),
        "--env",
        str(env),
        "--n_train_episodes",
        str(n_train_episodes),
    ]
    for k, v in args.items():
        command += [f"--{k}", str(v)]
    if wandb_project is not None:
        command += ["--wandb_project", str(wandb_project)]
    command.append("--no_save")

    # Pass an argument list and no shell: values taken from the command line
    # (env, wandb_project) reach run.py verbatim instead of being re-split or
    # interpreted by a shell.
    subprocess.run(command, check=False)
# Only start the sweep when executed as a script. Under the "spawn" and
# "forkserver" start methods every worker re-imports this module; without the
# guard each worker would try to build its own pool and start the sweep again.
if __name__ == "__main__":
    # One job per (update_type, gamma, epsilon[, size]) combination and
    # repetition. Only FrozenLake is additionally swept over `vals_size`.
    if env == "FrozenLake-v1":
        size_options = [{"size": size} for size in vals_size]
    else:
        size_options = [{}]

    tests = []
    for update_type in vals_update_type:
        for gamma in vals_gamma:
            for eps in vals_epsilon:
                for size_option in size_options:
                    for i in range(num_tests):
                        # Key order is the order of the flags on the
                        # generated command line.
                        tests.append(
                            {
                                "gamma": gamma,
                                "epsilon": eps,
                                "update_type": update_type,
                                **size_option,
                                "run_name_suffix": i,
                            }
                        )

    # Shuffle the jobs so the order in which configurations run is random.
    random.shuffle(tests)

    with multiprocessing.Pool(8) as p:
        p.map(run_test, tests)
|