File size: 5,371 Bytes
3e2038a
 
 
46b0409
3e2038a
99ac186
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d3766a
 
3e2038a
 
 
 
 
 
 
 
46b0409
3e2038a
99ac186
3e2038a
 
 
8ae24a2
46b0409
3e2038a
 
 
 
 
 
 
 
46b0409
30bb976
46b0409
30bb976
 
 
46b0409
 
3e2038a
 
 
 
 
 
 
 
6ee82fe
 
 
 
 
 
 
 
 
 
 
 
3e2038a
 
 
 
 
 
46b0409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae24a2
3e2038a
e173b06
 
 
3e2038a
6ee82fe
46b0409
 
3e2038a
 
 
 
 
 
 
6ee82fe
434e854
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
 
6ee82fe
3e2038a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import argparse
import sys

import wandb

from utils import AGENTS_MAP, load_agent


def main():
    """Command-line entry point: train or test an RL agent on a Gymnasium env.

    Parses CLI flags, instantiates the agent via ``load_agent`` (from a fresh
    agent name when training, or from a saved policy path when testing), then
    dispatches to ``agent.train(...)`` or ``agent.test(...)``. Optionally logs
    training runs to Weights & Biases when ``--wandb_project`` is given.

    Exits with status 1 (message on stderr) if neither --train nor --test is
    provided. A KeyboardInterrupt is caught so Ctrl-C exits cleanly.
    """
    parser = argparse.ArgumentParser()

    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent.",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2500,
        help="The number of episodes to train for. (default: 2500)",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=100,
        help="The number of episodes to test for. (default: 100)",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=100,
        help="During training, test the agent every n episodes. (default: 100)",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to the number of states in the environment. (default: None)",
    )

    ### Agent parameters
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=AGENTS_MAP.keys(),
        help=f"The agent to use. Currently supports one of: {list(AGENTS_MAP.keys())}",
    )

    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="The value for the discount factor to use. (default: 0.99)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.4,
        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
    )

    parser.add_argument(
        "--type",
        type=str,
        choices=["onpolicy", "offpolicy"],
        default="onpolicy",
        help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="The seed to use when generating the FrozenLake environment. If not provided, a random seed is used. (default: None)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=8,
        help="The size to use when generating the FrozenLake environment. (default: 8)",
    )
    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
    )

    ### Logging and saving parameters
    parser.add_argument(
        "--save_dir",
        type=str,
        default="policies",
        help="The directory to save the policy to. (default: policies)",
    )
    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Use this flag to disable saving the policy.",
    )
    parser.add_argument(
        "--run_name_suffix",
        type=str,
        default=None,
        help="Run name suffix for logging and policy checkpointing. (default: None)",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )

    args = parser.parse_args()
    # vars() is the public, documented way to get the Namespace as a dict
    # (the previous _get_kwargs() call is a private argparse API).
    config = vars(args)
    print(config)

    # Training loads a fresh agent by name; testing loads from the saved
    # policy path given via --test. All parsed flags are forwarded as kwargs.
    agent = load_agent(
        args.agent if args.test is None else args.test, **config
    )

    # Tag the run with the episode/step budget, plus any user-supplied suffix,
    # so checkpoints and W&B runs are distinguishable.
    agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
    if args.run_name_suffix is not None:
        agent.run_name += f"+{args.run_name_suffix}"

    try:
        if args.train:
            # Log to WandB only when a project was explicitly requested.
            if args.wandb_project is not None:
                wandb.init(
                    project=args.wandb_project,
                    name=agent.run_name,
                    group=args.agent,
                    job_type=args.wandb_job_type,
                    config=config,
                )

            agent.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                log_wandb=args.wandb_project is not None,
                save_best=True,
                save_best_dir=args.save_dir,
            )
            # The best-so-far policy is saved during training; this saves the
            # final policy as well unless the user opted out.
            if not args.no_save:
                agent.save_policy(save_dir=args.save_dir)
        elif args.test is not None:
            agent.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
        else:
            # Neither mode selected: report on stderr and exit nonzero so
            # shell scripts can detect the misuse (previously exited 0).
            print("ERROR: Please provide either --train or --test.", file=sys.stderr)
            sys.exit(1)
    except KeyboardInterrupt:
        # Allow Ctrl-C to stop a long run without a traceback.
        print("Exiting...")


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()