Spaces:
Sleeping
Sleeping
import argparse | |
import wandb | |
from utils import AGENTS_MAP, load_agent | |
def _build_parser():
    """Return the ArgumentParser that defines this script's command-line interface."""
    parser = argparse.ArgumentParser()

    ### Train/Test parameters
    parser.add_argument(
        "--train",
        action="store_true",
        help="Use this flag to train the agent.",
    )
    parser.add_argument(
        "--test",
        type=str,
        default=None,
        help="Use this flag to test the agent. Provide the path to the policy file.",
    )
    parser.add_argument(
        "--n_train_episodes",
        type=int,
        default=2500,
        help="The number of episodes to train for. (default: 2500)",
    )
    parser.add_argument(
        "--n_test_episodes",
        type=int,
        default=100,
        help="The number of episodes to test for. (default: 100)",
    )
    parser.add_argument(
        "--test_every",
        type=int,
        default=100,
        help="During training, test the agent every n episodes. (default: 100)",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="The maximum number of steps per episode before the episode is forced to end. If not provided, defaults to the number of states in the environment. (default: None)",
    )

    ### Agent parameters
    parser.add_argument(
        "--agent",
        type=str,
        required=True,
        choices=AGENTS_MAP.keys(),
        help=f"The agent to use. Currently supports one of: {list(AGENTS_MAP.keys())}",
    )
    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="The value for the discount factor to use. (default: 0.99)",
    )
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.4,
        help="The value for the epsilon-greedy policy to use. (default: 0.4)",
    )
    parser.add_argument(
        "--type",
        type=str,
        choices=["onpolicy", "offpolicy"],
        default="onpolicy",
        help="The type of update to use. Only supported by Monte-Carlo agent. (default: onpolicy)",
    )

    ### Environment parameters
    parser.add_argument(
        "--env",
        type=str,
        default="CliffWalking-v0",
        choices=["CliffWalking-v0", "FrozenLake-v1", "Taxi-v3"],
        help="The Gymnasium environment to use. (default: CliffWalking-v0)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="The seed to use when generating the FrozenLake environment. If not provided, a random seed is used. (default: None)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=8,
        help="The size to use when generating the FrozenLake environment. (default: 8)",
    )
    parser.add_argument(
        "--render_mode",
        type=str,
        default=None,
        help="Render mode passed to the gym.make() function. Use 'human' to render the environment. (default: None)",
    )

    # Logging and saving parameters
    parser.add_argument(
        "--save_dir",
        type=str,
        default="policies",
        help="The directory to save the policy to. (default: policies)",
    )
    parser.add_argument(
        "--no_save",
        action="store_true",
        help="Use this flag to disable saving the policy.",
    )
    parser.add_argument(
        "--run_name_suffix",
        type=str,
        default=None,
        help="Run name suffix for logging and policy checkpointing. (default: None)",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="WandB project name for logging. If not provided, no logging is done. (default: None)",
    )
    parser.add_argument(
        "--wandb_job_type",
        type=str,
        default="train",
        help="WandB job type for logging. (default: train)",
    )
    return parser


def main():
    """Parse command-line arguments, then train and/or test an agent.

    Modes:
      * ``--train``: train the agent, optionally logging to WandB when
        ``--wandb_project`` is set. Unless ``--no_save`` is given, the final
        policy is saved to ``--save_dir``.
      * ``--test PATH``: load the agent from the policy file at PATH and run
        ``agent.test``.
    If both flags are given, the agent is loaded from PATH and the training
    branch runs. If neither is given, argparse reports a usage error and
    exits with status 2 before any agent is loaded.

    A ``KeyboardInterrupt`` during training/testing is caught and reported
    with "Exiting...".
    """
    parser = _build_parser()
    args = parser.parse_args()
    print(vars(args))

    # Fail fast: check the mode before load_agent runs, and exit non-zero
    # (the old code loaded the agent, printed an error, then exited 0).
    if not args.train and args.test is None:
        parser.error("please provide either --train or --test.")

    # In test mode, load_agent receives the policy-file path instead of the agent name.
    agent = load_agent(
        args.agent if args.test is None else args.test, **dict(vars(args))
    )
    agent.run_name += f"_e{args.n_train_episodes}_s{args.max_steps}"
    if args.run_name_suffix is not None:
        agent.run_name += f"+{args.run_name_suffix}"

    log_wandb = args.wandb_project is not None
    try:
        if args.train:
            # Log to WandB
            if log_wandb:
                wandb.init(
                    project=args.wandb_project,
                    name=agent.run_name,
                    group=args.agent,
                    job_type=args.wandb_job_type,
                    config=dict(vars(args)),
                )
            agent.train(
                n_train_episodes=args.n_train_episodes,
                test_every=args.test_every,
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
                log_wandb=log_wandb,
                # Honour --no_save for best-so-far checkpoints too, not just
                # the final policy saved below.
                save_best=not args.no_save,
                save_best_dir=args.save_dir,
            )
            if not args.no_save:
                agent.save_policy(save_dir=args.save_dir)
        else:
            # The validation above guarantees args.test is set here.
            agent.test(
                n_test_episodes=args.n_test_episodes,
                max_steps=args.max_steps,
            )
    except KeyboardInterrupt:
        print("Exiting...")
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()