Spaces:
Sleeping
Sleeping
File size: 3,977 Bytes
4c2a92d 570282c 44db2f9 a777e34 3cafd2c 1c007bb 44db2f9 350e00d 4c2a92d fa34b1d 23fd1ff 4c2a92d 3cafd2c 1bd428f 62c6c3b 44db2f9 350e00d 1bd428f 4c2a92d 3cafd2c 4c2a92d fa34b1d 23fd1ff fa34b1d 4c2a92d 3cafd2c 4c2a92d 1c007bb 4c2a92d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#!/usr/bin/env python3
import argparse
import os
import time
import matplotlib.pyplot as plt
from a3c.train import train
from a3c.eval import evaluate, evaluate_checkpoints
from a3c.play import suggest
from wordle_env.wordle import get_env
def training_mode(args, env, model_checkpoint_dir):
    """Run the ``train`` sub-command: train a model, then report and evaluate it.

    Args:
        args: Parsed CLI namespace; reads ``games``, ``model_name``, ``gamma``,
            ``seed``, ``save``, ``min_reward`` and ``every_n_save``.
        env: Environment object returned by ``get_env``.
        model_checkpoint_dir: Directory passed to ``train`` and used as the base
            directory for the optional pretrained model file.
    """
    games_to_play = args.games
    t_begin = time.time()

    # Only build a path when a pretrained model was named; otherwise the falsy
    # value is forwarded unchanged (presumably meaning "train from scratch" —
    # confirm against a3c.train).
    if args.model_name:
        pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    else:
        pretrained_model_path = args.model_name

    global_ep, win_ep, gnet, res = train(
        env,
        games_to_play,
        model_checkpoint_dir,
        args.gamma,
        args.seed,
        pretrained_model_path,
        args.save,
        args.min_reward,
        args.every_n_save,
    )

    elapsed = time.time() - t_begin
    print("--- {:.0f} seconds ---".format(elapsed))
    print_results(global_ep, win_ep, res)
    evaluate(gnet, env)
def evaluation_mode(args, env, model_checkpoint_dir):
    """Run the ``eval`` sub-command: evaluate saved checkpoints and print the outcome.

    Args:
        args: Parsed CLI namespace. It is not read here; it is accepted so that
            every mode handler shares the same dispatch signature.
        env: Environment object returned by ``get_env``.
        model_checkpoint_dir: Directory handed to ``evaluate_checkpoints``.
    """
    print("Evaluation mode")
    print(evaluate_checkpoints(model_checkpoint_dir, env))
def play_mode(args, env, model_checkpoint_dir):
    """Run the ``play`` sub-command: print the model's suggested next word.

    Args:
        args: Parsed CLI namespace. Reads ``words`` and ``states``, which are
            comma-separated strings with one state per played word, and
            ``model_name``.
        env: Environment object returned by ``get_env``.
        model_checkpoint_dir: Directory containing the pretrained model file.

    Raises:
        SystemExit: if the number of (non-empty) words and states differ.
    """
    print("Play mode")
    # Drop empty entries so a trailing or doubled comma ("crane,slate,") does
    # not inject a blank word/state into the game history.
    words = [word.strip() for word in args.words.split(',') if word.strip()]
    states = [state.strip() for state in args.states.split(',') if state.strip()]
    # Each state is the feedback for the word at the same position, so the two
    # lists must pair up one-to-one. Fail early with a readable CLI message
    # instead of handing mismatched lists to ``suggest``.
    if len(words) != len(states):
        raise SystemExit(
            "error: got %d word(s) but %d state(s); --words and --states must "
            "have one state per word" % (len(words), len(states)))
    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    word = suggest(env, words, states, pretrained_model_path)
    print(word)
def print_results(global_ep, win_ep, res):
    """Print the played/won episode counters and plot the training reward curve.

    Args:
        global_ep: Object exposing ``.value`` with the number of episodes played
            (presumably a multiprocessing shared counter — confirm in a3c.train).
        win_ep: Object exposing ``.value`` with the number of episodes won.
        res: Sequence of moving-average episode rewards, one point per step.
    """
    for label, counter in (("Jugadas:", global_ep), ("Ganadas:", win_ep)):
        print(label, counter.value)

    plt.plot(res)
    plt.xlabel('Step')
    plt.ylabel('Moving average ep reward')
    plt.show()  # blocks until the plot window is closed
def _build_parser():
    """Return the CLI parser.

    The parser takes a positional environment id plus exactly one sub-command
    (``train``, ``eval`` or ``play``). The sub-command's handler is stored in
    ``args.func``.
    """
    parser = argparse.ArgumentParser()
    # NOTE: the positional keeps its historical spelling "enviroment" because
    # that is the attribute name read back as ``args.enviroment``.
    parser.add_argument(
        "enviroment",
        help="Environment (type of wordle game) used for training, example: WordleEnvFull-v0")
    parser.add_argument(
        "--models_dir",
        help="Directory where models are saved (default=checkpoints)",
        default='checkpoints')

    # A sub-command is mandatory: without one, ``args.func`` would not exist and
    # the dispatch in ``main`` would crash with AttributeError. ``dest`` lets
    # argparse name the missing argument in its error message. Setting
    # ``required`` as an attribute also works on Pythons older than 3.7.
    subparsers = parser.add_subparsers(dest='command', help='sub-command help')
    subparsers.required = True

    parser_train = subparsers.add_parser(
        'train', help='Train a model from scratch or train from pretrained model')
    parser_train.add_argument(
        "--games", "-g", help="Number of games to train", type=int, required=True)
    parser_train.add_argument(
        "--model_name", "-m",
        help="To train from a pretrained model, the name of the pretrained model file")
    # NOTE(review): a discount factor of 0 presumably means only the immediate
    # reward counts; confirm this default is intended.
    parser_train.add_argument(
        "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
    parser_train.add_argument(
        "--seed", help="Seed used for random numbers generation", type=int, default=100)
    parser_train.add_argument(
        "--save", '-s', help="Save instances of the model while training", action='store_true')
    parser_train.add_argument(
        "--min_reward",
        help="The minimum global reward value achieved for saving the model",
        type=float, default=9.9)
    parser_train.add_argument(
        "--every_n_save", help="Check every n training steps to save the model", type=int, default=100)
    parser_train.set_defaults(func=training_mode)

    parser_eval = subparsers.add_parser(
        'eval', help='Evaluate saved models for the environment')
    parser_eval.set_defaults(func=evaluation_mode)

    parser_play = subparsers.add_parser(
        'play',
        help='Give the model a word and the state result and the model will try to predict the goal word')
    parser_play.add_argument(
        "--words", "-w", help="List of words played in the wordle game", required=True)
    parser_play.add_argument(
        "--states", "-st", help="List of states returned by playing each of the words", required=True)
    parser_play.add_argument(
        "--model_name", "-m",
        help="Name of the pretrained model file which will play the game", required=True)
    parser_play.set_defaults(func=play_mode)

    return parser


def main(argv=None):
    """Parse arguments, build the environment and run the selected sub-command.

    Args:
        argv: Argument list to parse. ``None`` makes argparse use ``sys.argv[1:]``.
    """
    args = _build_parser().parse_args(argv)
    env = get_env(args.enviroment)
    # Checkpoints live in a per-environment sub-directory of --models_dir.
    model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
    args.func(args, env, model_checkpoint_dir)


if __name__ == "__main__":
    main()
|