File size: 3,977 Bytes
4c2a92d
 
 
 
570282c
44db2f9
a777e34
 
3cafd2c
1c007bb
44db2f9
350e00d
4c2a92d
 
 
fa34b1d
23fd1ff
4c2a92d
 
 
 
 
 
 
 
 
 
 
3cafd2c
 
 
 
 
 
 
 
 
1bd428f
62c6c3b
 
44db2f9
 
 
350e00d
1bd428f
 
 
4c2a92d
 
 
 
 
 
 
 
 
 
 
 
3cafd2c
4c2a92d
fa34b1d
23fd1ff
 
fa34b1d
 
 
 
 
 
4c2a92d
 
 
 
 
 
3cafd2c
 
 
 
 
 
 
 
 
 
4c2a92d
 
1c007bb
4c2a92d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3

import argparse
import os
import time
import matplotlib.pyplot as plt
from a3c.train import train
from a3c.eval import evaluate, evaluate_checkpoints
from a3c.play import suggest
from wordle_env.wordle import get_env


def training_mode(args, env, model_checkpoint_dir):
    """Train a model on ``env``, report the run, then evaluate the result.

    Args:
        args: parsed ``train`` sub-command namespace. Reads ``games``,
            ``model_name``, ``gamma``, ``seed``, ``save``, ``min_reward`` and
            ``every_n_save``.
        env: the environment returned by ``get_env``.
        model_checkpoint_dir: directory handed to ``train``. It is also the
            base for resolving ``args.model_name``.
    """
    started_at = time.time()

    # Without a model name the falsy value itself is forwarded unchanged.
    # Otherwise the name is resolved inside the checkpoint directory.
    pretrained_model_path = args.model_name
    if pretrained_model_path:
        pretrained_model_path = os.path.join(model_checkpoint_dir, pretrained_model_path)

    global_ep, win_ep, gnet, res = train(
        env,
        args.games,
        model_checkpoint_dir,
        args.gamma,
        args.seed,
        pretrained_model_path,
        args.save,
        args.min_reward,
        args.every_n_save,
    )

    elapsed = time.time() - started_at
    print("--- %.0f seconds ---" % elapsed)
    print_results(global_ep, win_ep, res)
    # ``gnet`` is the third value returned by ``train``; it is the object evaluated here.
    evaluate(gnet, env)


def evaluation_mode(args, env, model_checkpoint_dir):
    """Evaluate the checkpoints in ``model_checkpoint_dir`` on ``env`` and print them.

    ``args`` is not used. It is accepted so every sub-command handler shares
    the same ``(args, env, model_checkpoint_dir)`` signature.
    """
    print("Evaluation mode")
    print(evaluate_checkpoints(model_checkpoint_dir, env))


def _split_csv(value):
    """Split a comma-separated CLI value into stripped, non-empty items."""
    # Empty items (e.g. from a trailing comma or ",,") are dropped rather than
    # forwarded to the model as blank words/states.
    return [item.strip() for item in value.split(',') if item.strip()]


def play_mode(args, env, model_checkpoint_dir):
    """Ask a pretrained model to suggest the next word and print it.

    Args:
        args: parsed ``play`` sub-command namespace. Reads ``words`` and
            ``states`` (comma-separated strings) and ``model_name``.
        env: the environment returned by ``get_env``.
        model_checkpoint_dir: directory containing the model file ``args.model_name``.

    Raises:
        SystemExit: when the number of words and states differ. Each state
            is the result of playing the corresponding word, so the two lists
            must pair up one-to-one.
    """
    print("Play mode")
    words = _split_csv(args.words)
    states = _split_csv(args.states)
    if len(words) != len(states):
        raise SystemExit(
            "error: --words and --states must have the same number of entries "
            "(got %d words and %d states)" % (len(words), len(states)))
    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
    word = suggest(env, words, states, pretrained_model_path)
    print(word)


def print_results(global_ep, win_ep, res):
    """Print the episode counters and plot the reward curve.

    Args:
        global_ep: object with a ``.value`` attribute, printed as the number
            of games played.
        win_ep: object with a ``.value`` attribute, printed as the number of
            games won.
        res: sequence passed to ``plt.plot``; labelled as the moving-average
            episode reward.
    """
    # Labels are Spanish: "Jugadas" = games played, "Ganadas" = games won.
    for label, counter in (("Jugadas:", global_ep), ("Ganadas:", win_ep)):
        print(label, counter.value)
    plt.plot(res)
    plt.xlabel('Step')
    plt.ylabel('Moving average ep reward')
    plt.show()


def _build_parser():
    """Build the command-line parser with the ``train``/``eval``/``play`` sub-commands."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "environment", help="Environment (type of wordle game) to use, example: WordleEnvFull-v0")
    parser.add_argument(
        "--models_dir", help="Directory where models are saved (default=checkpoints)", default='checkpoints')
    # ``required=True`` makes argparse reject a missing sub-command with a usage
    # error. Otherwise ``args.func`` would never be set and dispatching would
    # crash with AttributeError. ``dest`` lets argparse name the missing
    # argument in that error message.
    subparsers = parser.add_subparsers(dest='mode', required=True, help='sub-command help')

    parser_train = subparsers.add_parser(
        'train', help='Train a model from scratch or train from pretrained model')
    parser_train.add_argument(
        "--games", "-g", help="Number of games to train", type=int, required=True)
    parser_train.add_argument(
        "--model_name", "-m", help="To train from a pretrained model, the name of the pretrained model file")
    parser_train.add_argument(
        "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
    parser_train.add_argument(
        "--seed", help="Seed used for random numbers generation", type=int, default=100)
    parser_train.add_argument(
        "--save", '-s', help="Save instances of the model while training", action='store_true')
    parser_train.add_argument(
        "--min_reward", help="The minimum global reward value that must be achieved to save the model", type=float, default=9.9)
    parser_train.add_argument(
        "--every_n_save", help="Check every n training steps to save the model", type=int, default=100)
    parser_train.set_defaults(func=training_mode)

    parser_eval = subparsers.add_parser(
        'eval', help='Evaluate saved models for the environment')
    parser_eval.set_defaults(func=evaluation_mode)

    parser_play = subparsers.add_parser(
        'play', help='Give the model a word and the state result and the model will try to predict the goal word')
    parser_play.add_argument(
        "--words", "-w", help="List of words played in the wordle game", required=True)
    parser_play.add_argument(
        "--states", "-st", help="List of states returned by playing each of the words", required=True)
    parser_play.add_argument(
        "--model_name", "-m", help="Name of the pretrained model file which will play the game", required=True)
    parser_play.set_defaults(func=play_mode)

    return parser


def main(argv=None):
    """Parse ``argv`` (``sys.argv[1:]`` when None), build the environment and run the chosen mode."""
    args = _build_parser().parse_args(argv)
    env = get_env(args.environment)
    # Models are stored in a sub-directory named after the environment's spec id.
    model_checkpoint_dir = os.path.join(args.models_dir, env.unwrapped.spec.id)
    args.func(args, env, model_checkpoint_dir)


if __name__ == "__main__":
    main()