File size: 5,297 Bytes
3e2038a
 
 
6ee82fe
3e2038a
99ac186
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3af403
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99ac186
3e2038a
 
 
8ae24a2
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99ac186
6ee82fe
 
 
 
 
 
 
 
 
 
 
 
 
 
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae24a2
3e2038a
6ee82fe
3e2038a
6ee82fe
 
 
3e2038a
 
 
 
 
 
 
6ee82fe
434e854
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ee82fe
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import argparse
import wandb

from agents import AGENTS_MAP, load_agent


def main():
    """Command-line entry point: train an agent or test a saved policy.

    Parses the command-line options, builds the agent selected with
    ``--agent`` via ``load_agent`` (every parsed option is forwarded as a
    keyword argument), and then either

    * trains it (``--train``), optionally logging to Weights & Biases when
      ``--wandb_project`` is given and saving the final policy unless
      ``--no_save`` is set, or
    * loads the policy file given by ``--test`` and evaluates it.

    If both ``--train`` and ``--test`` are supplied, training takes
    precedence. If neither is supplied, the parser reports a usage error and
    the process exits with a non-zero status before any agent is built.
    A ``KeyboardInterrupt`` during training/testing is caught and reported
    with a short message instead of a traceback.
    """
    parser = argparse.ArgumentParser()
    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent.",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2500,
        help="The number of episodes to train for. (default: 2500)",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=100,
        help="The number of episodes to test for. (default: 100)",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=100,
        help="During training, test the agent every n episodes. (default: 100)",
    )

    parser.add_argument(
        "--max_steps",
        type=int,
        default=200,
        help="The maximum number of steps per episode before the episode is forced to end. (default: 200)",
    )

    parser.add_argument(
        "--update_type",
        type=str,
        choices=["first_visit", "every_visit"],
        default="first_visit",
        help="The type of update to use. Only supported by Monte-Carlo agent. (default: first_visit)",
    )

    parser.add_argument(
        "--save_dir",
        type=str,
        default="policies",
        help="The directory to save the policy to. (default: policies)",
    )

    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Use this flag to disable saving the policy.",
    )

    ### Agent parameters
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=AGENTS_MAP.keys(),
        # Join the keys so the help shows plain names, not a dict_keys(...) repr.
        help=f"The agent to use. One of: {', '.join(AGENTS_MAP)}",
    )

    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        # The documented default must match the actual default above.
        help="The value for the discount factor to use. (default: 0.99)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.4,
        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="The seed to use when generating the FrozenLake environment. If not provided, a random seed is used. (default: None)",
    )

    parser.add_argument(
        "--size",
        type=int,
        default=8,
        help="The size to use when generating the FrozenLake environment. (default: 8)",
    )

    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
    )

    ### Logging parameters
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )
    parser.add_argument(
        "--wandb_run_name_suffix",
        type=str,
        default=None,
        help="WandB run name suffix for logging. (default: None)",
    )

    args = parser.parse_args()
    print(vars(args))

    # Reject an unusable invocation up front: this avoids building the agent
    # for nothing and makes the process exit with a non-zero status.
    if not args.train and args.test is None:
        parser.error("Please provide either --train or --test.")

    # vars() is the public way to view the parsed options as a dict.
    agent = load_agent(args.agent, **vars(args))

    # Encode the main run settings in the run name.
    agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
    if args.wandb_run_name_suffix is not None:
        agent.run_name += f"+{args.wandb_run_name_suffix}"

    use_wandb = args.wandb_project is not None

    try:
        if args.train:
            # Log to WandB
            if use_wandb:
                wandb.init(
                    project=args.wandb_project,
                    name=agent.run_name,
                    group=args.agent,
                    job_type=args.wandb_job_type,
                    config=dict(vars(args)),
                )

            # NOTE(review): save_best=True is passed regardless of --no_save;
            # --no_save only skips the final save_policy() call below.
            # Confirm this is the intended meaning of the flag.
            agent.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                update_type=args.update_type,
                log_wandb=use_wandb,
                save_best=True,
                save_best_dir=args.save_dir,
            )
            if not args.no_save:
                agent.save_policy(save_dir=args.save_dir)
        else:
            # The up-front check guarantees args.test is set on this branch.
            agent.load_policy(args.test)
            agent.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
    except KeyboardInterrupt:
        print("Exiting...")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()