File size: 2,788 Bytes
46b0409
 
 
 
 
 
0f41753
 
 
 
 
120dc90
0f41753
 
 
 
 
120dc90
0f41753
 
 
 
 
 
 
 
46b0409
0f41753
46b0409
0f41753
 
46b0409
0f41753
7d3766a
 
0f41753
7d3766a
0f41753
7d3766a
0f41753
46b0409
7d3766a
 
0f41753
 
120dc90
0f41753
7d3766a
120dc90
0f41753
 
120dc90
0f41753
 
46b0409
 
0f41753
 
7d3766a
 
 
 
0f41753
 
 
 
46b0409
 
0f41753
 
 
 
 
7d3766a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f41753
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import itertools
import multiprocessing
import os
import random
import subprocess
import sys


# Command-line interface for the sweep: which environment to train on, how
# many repeated runs to launch per parameter combination, and an optional
# wandb project that every launched run should log to.
parser = argparse.ArgumentParser(description="Run parameter tests for MC agent")
parser.add_argument("--env", type=str, default="FrozenLake-v1", help="environment to run")
parser.add_argument(
    "--num_tests",
    type=int,
    default=10,
    help="number of tests to run for each parameter combination",
)
parser.add_argument("--wandb_project", type=str, default=None, help="wandb project name to log to")

args = parser.parse_args()

env = args.env
num_tests = args.num_tests
wandb_project = args.wandb_project

# Agent class name passed through to run.py for every launched run.
agent = "MCAgent"

# Hyperparameter grid swept for every environment.
# The "on_policy" update type is currently disabled; the original note reads:
# "Every visit takes too long due to these environment's reward structure".
# Alternative values used previously: gamma in [1.0, 0.98, 0.96, 0.94],
# epsilon in [0.5].
vals_update_type = ["off_policy"]
vals_gamma = [1.0]
vals_epsilon = [0.1, 0.2, 0.3, 0.4, 0.5]

# Map sizes swept in addition to the grid above (used for FrozenLake only).
vals_size = [8, 16, 32, 64]

# Training episodes per run, keyed by environment id. A per-env max_steps
# (200 / 200 / 500 respectively) was once passed too but is currently unused.
_N_TRAIN_EPISODES = {
    "CliffWalking-v0": 2500,
    "FrozenLake-v1": 25000,
    "Taxi-v3": 10000,
}

if env not in _N_TRAIN_EPISODES:
    raise ValueError(f"Unsupported environment: {env}")
n_train_episodes = _N_TRAIN_EPISODES[env]


def run_test(args):
    """Launch one ``run.py --train`` process and wait for it to finish.

    The command line is built from the module-level ``agent``, ``env``,
    ``n_train_episodes`` and ``wandb_project`` settings plus one
    ``--<key> <value>`` pair per entry in ``args``; ``--no_save`` is always
    appended.

    Args:
        args: Mapping of ``run.py`` option names (without the leading ``--``)
            to values, e.g. ``{"gamma": 1.0, "epsilon": 0.1, ...}``. Note this
            parameter shadows the module-level argparse namespace of the same
            name.

    Returns:
        The child process's exit status (0 on success).
    """
    # Use the interpreter running this sweep rather than whatever "python3"
    # resolves to on PATH, so the child sees the same environment/packages.
    command = [
        sys.executable,
        "run.py",
        "--train",
        "--agent", agent,
        "--env", env,
        "--n_train_episodes", str(n_train_episodes),
    ]
    for key, value in args.items():
        command += [f"--{key}", str(value)]
    if wandb_project is not None:
        command += ["--wandb_project", wandb_project]
    command.append("--no_save")
    # Passing an argument list (no shell) keeps each value a single literal
    # argument even if it contains spaces or shell metacharacters.
    return subprocess.run(command, check=False).returncode


if __name__ == "__main__":
    # Only run the sweep when executed as a script. With the "spawn" start
    # method (the default on macOS and Windows) every pool worker re-imports
    # this module; without this guard each worker would try to create its own
    # Pool during bootstrapping and fail.

    # Only FrozenLake sweeps over map size; other environments get a single
    # placeholder entry so the product below has the same shape.
    sizes = vals_size if env == "FrozenLake-v1" else [None]

    # One job per (update_type, gamma, epsilon[, size]) combination, repeated
    # num_tests times with distinct run_name_suffix values.
    tests = []
    for update_type, gamma, eps, size in itertools.product(
        vals_update_type, vals_gamma, vals_epsilon, sizes
    ):
        params = {"gamma": gamma, "epsilon": eps, "update_type": update_type}
        if size is not None:
            params["size"] = size
        tests.extend({**params, "run_name_suffix": i} for i in range(num_tests))

    # Randomize job order instead of running the grid in order.
    random.shuffle(tests)

    with multiprocessing.Pool(8) as pool:
        results = pool.map(run_test, tests)

    # Surface failed runs instead of discarding them; a falsy result (0 or
    # None) is treated as success.
    n_failed = sum(1 for status in results if status)
    if n_failed:
        print(f"{n_failed} of {len(tests)} runs exited with a non-zero status")