Spaces:
Sleeping
Sleeping
File size: 5,371 Bytes
3e2038a 46b0409 3e2038a 99ac186 3e2038a 7d3766a 3e2038a 46b0409 3e2038a 99ac186 3e2038a 8ae24a2 46b0409 3e2038a 46b0409 30bb976 46b0409 30bb976 46b0409 3e2038a 6ee82fe 3e2038a 46b0409 3e2038a 8ae24a2 3e2038a e173b06 3e2038a 6ee82fe 46b0409 3e2038a 6ee82fe 434e854 3e2038a 6ee82fe 3e2038a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
import argparse
import wandb
from utils import AGENTS_MAP, load_agent
def _build_parser():
    """Build the command-line parser for training/testing an agent.

    Returns:
        argparse.ArgumentParser: parser covering train/test options, agent
        hyper-parameters, environment selection, and logging/saving options.
    """
    parser = argparse.ArgumentParser()

    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent.",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2500,
        help="The number of episodes to train for. (default: 2500)",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=100,
        help="The number of episodes to test for. (default: 100)",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=100,
        help="During training, test the agent every n episodes. (default: 100)",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to the number of states in the environment. (default: None)",
    )

    ### Agent parameters
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        # A concrete list (rather than a live dict view) for argparse's
        # membership checks and help/error formatting.
        choices=list(AGENTS_MAP),
        help=f"The agent to use. Currently supports one of: {list(AGENTS_MAP.keys())}",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="The value for the discount factor to use. (default: 0.99)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.4,
        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["onpolicy", "offpolicy"],
        default="onpolicy",
        help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="The seed to use when generating the FrozenLake environment. If not provided, a random seed is used. (default: None)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=8,
        help="The size to use when generating the FrozenLake environment. (default: 8)",
    )
    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
    )

    # Logging and saving parameters
    parser.add_argument(
        "--save_dir",
        type=str,
        default="policies",
        help="The directory to save the policy to. (default: policies)",
    )
    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Use this flag to disable saving the policy.",
    )
    parser.add_argument(
        "--run_name_suffix",
        type=str,
        default=None,
        help="Run name suffix for logging and policy checkpointing. (default: None)",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )
    return parser


def main():
    """Parse CLI arguments, then train or test the selected agent.

    With ``--train`` the agent is trained (optionally logging to WandB when
    ``--wandb_project`` is set) and its policy is saved unless
    ``--no_save`` is given. With ``--test PATH`` the agent is loaded from
    ``PATH`` and evaluated. Supplying neither flag is a usage error.
    A Ctrl+C during training/testing is caught and reported as "Exiting...".
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Validate the run mode before doing any work (loading the agent,
    # starting a WandB run). parser.error prints usage to stderr and exits
    # with a non-zero status, unlike a plain print.
    if not args.train and args.test is None:
        parser.error("Please provide either --train or --test.")

    print(vars(args))

    # Public equivalent of the private Namespace._get_kwargs().
    cli_kwargs = vars(args)

    # With --test, the first argument is the policy file path; otherwise it
    # is the agent name.
    agent = load_agent(
        args.agent if args.test is None else args.test, **cli_kwargs
    )
    agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
    if args.run_name_suffix is not None:
        agent.run_name += f"+{args.run_name_suffix}"

    log_wandb = args.wandb_project is not None
    wandb_started = False
    try:
        if args.train:
            # Log to WandB
            if log_wandb:
                wandb.init(
                    project=args.wandb_project,
                    name=agent.run_name,
                    group=args.agent,
                    job_type=args.wandb_job_type,
                    # Copy so the run config never aliases the live namespace.
                    config=dict(cli_kwargs),
                )
                wandb_started = True
            agent.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                log_wandb=log_wandb,
                save_best=True,
                save_best_dir=args.save_dir,
            )
            if not args.no_save:
                agent.save_policy(save_dir=args.save_dir)
        else:
            # The validation above guarantees --test was provided here.
            agent.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
    except KeyboardInterrupt:
        print("Exiting...")
    finally:
        # Close the WandB run explicitly, including after Ctrl+C, rather
        # than relying on interpreter-exit hooks.
        if wandb_started:
            wandb.finish()


if __name__ == "__main__":
    main()
|