# Andrei Cozma
# Updates
# 30bb976
import argparse
import wandb
from utils import AGENTS_MAP, load_agent
def _build_parser():
    """Build the command-line parser for training/testing an agent.

    Returns:
        argparse.ArgumentParser: parser with train/test, agent, environment,
        and logging/saving options. The ``--agent`` choices are taken from
        ``AGENTS_MAP``.
    """
    parser = argparse.ArgumentParser()

    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent.",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2500,
        help="The number of episodes to train for. (default: 2500)",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=100,
        help="The number of episodes to test for. (default: 100)",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=100,
        help="During training, test the agent every n episodes. (default: 100)",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to the number of states in the environment. (default: None)",
    )

    ### Agent parameters
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=AGENTS_MAP.keys(),
        help=f"The agent to use. Currently supports one of: {list(AGENTS_MAP.keys())}",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="The value for the discount factor to use. (default: 0.99)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.4,
        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["onpolicy", "offpolicy"],
        default="onpolicy",
        help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="The seed to use when generating the FrozenLake environment. If not provided, a random seed is used. (default: None)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=8,
        help="The size to use when generating the FrozenLake environment. (default: 8)",
    )
    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
    )

    # Logging and saving parameters
    parser.add_argument(
        "--save_dir",
        type=str,
        default="policies",
        help="The directory to save the policy to. (default: policies)",
    )
    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Use this flag to disable saving the policy.",
    )
    parser.add_argument(
        "--run_name_suffix",
        type=str,
        default=None,
        help="Run name suffix for logging and policy checkpointing. (default: None)",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )
    return parser


def main():
    """Parse CLI arguments, then train or test the selected agent.

    ``--train`` takes precedence over ``--test``. When both are given, the
    agent is loaded from the ``--test`` value and then trained. When neither
    is given, the parser exits with a usage error before any agent is built.

    A ``KeyboardInterrupt`` during training/testing is caught and reported
    as "Exiting..." instead of a traceback.
    """
    parser = _build_parser()
    args = parser.parse_args()
    print(vars(args))

    # Fail fast on a missing mode: usage goes to stderr and the exit status
    # is non-zero, and no agent is constructed for nothing.
    if not args.train and args.test is None:
        parser.error("Please provide either --train or --test.")

    # Public equivalent of the private Namespace._get_kwargs(); copied so
    # consumers get a plain dict independent of the Namespace object.
    config = dict(vars(args))

    # With --test, its value (the policy file path) is passed to load_agent
    # in place of the agent name.
    agent = load_agent(args.agent if args.test is None else args.test, **config)
    agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
    if args.run_name_suffix is not None:
        agent.run_name += f"+{args.run_name_suffix}"

    try:
        if args.train:
            # Log to WandB only when a project name was provided.
            if args.wandb_project is not None:
                wandb.init(
                    project=args.wandb_project,
                    name=agent.run_name,
                    group=args.agent,
                    job_type=args.wandb_job_type,
                    config=config,
                )
            agent.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                log_wandb=args.wandb_project is not None,
                save_best=True,
                save_best_dir=args.save_dir,
            )
            if not args.no_save:
                agent.save_policy(save_dir=args.save_dir)
        else:
            # Reaching here implies args.test is not None (validated above).
            agent.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
    except KeyboardInterrupt:
        print("Exiting...")
if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing the module
    # just defines main() without parsing sys.argv.
    main()