santit96 committed on
Commit
3cafd2c
·
1 Parent(s): fa34b1d

Add play mode

Browse files

Given a word, a state, and a saved model, the model returns the probable goal word.

Files changed (3) hide show
  1. a3c/eval.py +1 -15
  2. a3c/play.py +48 -0
  3. main.py +21 -1
a3c/eval.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import torch
3
 
4
  from .net import GreedyNet
 
5
  from .utils import v_wrap
6
 
7
 
@@ -38,18 +39,3 @@ def evaluate(net, env):
38
  print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
39
  f"{n_guesses / N} including losses.")
40
  return n_wins/N*100, n_win_guesses/n_wins
41
-
42
-
43
- def play(net, env):
44
- state = env.reset()
45
- outcomes = []
46
- win = False
47
- for i in range(env.max_turns):
48
- action = net.choose_action(v_wrap(state[None, :]))
49
- state, reward, done, _ = env.step(action)
50
- outcomes.append((env.words[action], reward))
51
- if done:
52
- if reward >= 0:
53
- win = True
54
- break
55
- return win, outcomes
 
2
  import torch
3
 
4
  from .net import GreedyNet
5
+ from .play import play
6
  from .utils import v_wrap
7
 
8
 
 
39
  print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
40
  f"{n_guesses / N} including losses.")
41
  return n_wins/N*100, n_win_guesses/n_wins
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3c/play.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .net import GreedyNet
3
+ from .utils import v_wrap
4
+ from wordle_env.state import update_from_mask
5
+
6
+
7
def suggest(
    env,
    words,
    states,
    pretrained_model_path
) -> str:
    """
    Given the words played so far and their result masks, return the next suggested word.

    :param env: Gym environment; must expose ``observation_space`` and ``action_space``
        and, once unwrapped, ``reset()`` and ``words``
    :param words: Words already played, in order (upper-cased before use)
    :param states: One mask per played word; every item of a mask is converted with ``int``
    :param pretrained_model_path: Path of the saved ``GreedyNet`` state dict
    :return: The entry of ``env.words`` chosen by the network for the resulting state
    :raises ValueError: If ``words`` and ``states`` do not have the same length
    """
    # Validate up front: zip() would otherwise silently drop the unmatched
    # entries and the suggestion would be computed from a truncated history.
    words = list(words)
    states = list(states)
    if len(words) != len(states):
        raise ValueError(
            f"words and states must have the same length, "
            f"got {len(words)} words and {len(states)} states")

    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    env = env.unwrapped
    state = env.reset()
    words_list = env.words
    word_width = len(words_list[0])

    net = GreedyNet(n_s, n_a, words_list, word_width)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(pretrained_model_path, map_location="cpu"))

    # Replay the history on top of the freshly reset state.
    for word, mask in zip(words, states):
        state = update_from_mask(state, word.upper(), [int(flag) for flag in mask])

    return words_list[net.choose_action(v_wrap(state[None, :]))]
34
+
35
+
36
def play(net, env):
    """
    Play one game on ``env`` using the actions chosen by ``net``.

    :param net: Object whose ``choose_action`` picks an action for a wrapped state
    :param env: Environment exposing ``reset()``, ``step()``, ``max_turns`` and ``words``
    :return: Tuple ``(win, outcomes)``; ``win`` is True only when the game finished
        with a non-negative reward, ``outcomes`` lists ``(word, reward)`` per turn
    """
    state = env.reset()
    history = []
    for _turn in range(env.max_turns):
        chosen = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _info = env.step(chosen)
        history.append((env.words[chosen], reward))
        if not done:
            continue
        # The game is over: a non-negative final reward counts as a win.
        return bool(reward >= 0), history
    # Ran out of turns without the environment reporting completion.
    return False, history
main.py CHANGED
@@ -8,6 +8,7 @@ import time
8
  import matplotlib.pyplot as plt
9
  from a3c.train import train
10
  from a3c.eval import evaluate, evaluate_checkpoints
 
11
  from wordle_env.wordle import WordleEnvBase
12
 
13
 
@@ -27,6 +28,15 @@ def evaluation_mode(args, env, model_checkpoint_dir):
27
  print(results)
28
 
29
 
 
 
 
 
 
 
 
 
 
30
  def print_results(global_ep, win_ep, res):
31
  print("Jugadas:", global_ep.value)
32
  print("Ganadas:", win_ep.value)
@@ -49,7 +59,7 @@ if __name__ == "__main__":
49
  parser_train.add_argument(
50
  "--games", "-g", help="Number of games to train", type=int, required=True)
51
  parser_train.add_argument(
52
- "--model_name", "-n", help="If want to train from a pretrained model, the name of the pretrained model file")
53
  parser_train.add_argument(
54
  "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
55
  parser_train.add_argument(
@@ -64,6 +74,16 @@ if __name__ == "__main__":
64
  'eval', help='Evaluate saved models for the enviroment')
65
  parser_eval.set_defaults(func=evaluation_mode)
66
 
 
 
 
 
 
 
 
 
 
 
67
  args = parser.parse_args()
68
  env_id = args.enviroment
69
  env = gym.make(env_id)
 
8
  import matplotlib.pyplot as plt
9
  from a3c.train import train
10
  from a3c.eval import evaluate, evaluate_checkpoints
11
+ from a3c.play import suggest
12
  from wordle_env.wordle import WordleEnvBase
13
 
14
 
 
28
  print(results)
29
 
30
 
31
def play_mode(args, env, model_checkpoint_dir):
    """
    Print the word suggested by a saved model for the history given on the command line.

    :param args: Parsed arguments providing ``words`` and ``states`` (comma-separated
        strings) and ``model_name``
    :param env: Environment handed over to ``suggest``
    :param model_checkpoint_dir: Directory containing the saved model file
    """
    print("Play mode")
    # Split the comma-separated inputs and drop surrounding whitespace of each item.
    played_words = list(map(str.strip, args.words.split(',')))
    result_masks = list(map(str.strip, args.states.split(',')))
    checkpoint_path = os.path.join(model_checkpoint_dir, args.model_name)
    print(suggest(env, played_words, result_masks, checkpoint_path))
38
+
39
+
40
  def print_results(global_ep, win_ep, res):
41
  print("Jugadas:", global_ep.value)
42
  print("Ganadas:", win_ep.value)
 
59
  parser_train.add_argument(
60
  "--games", "-g", help="Number of games to train", type=int, required=True)
61
  parser_train.add_argument(
62
+ "--model_name", "-m", help="If want to train from a pretrained model, the name of the pretrained model file")
63
  parser_train.add_argument(
64
  "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
65
  parser_train.add_argument(
 
74
  'eval', help='Evaluate saved models for the enviroment')
75
  parser_eval.set_defaults(func=evaluation_mode)
76
 
77
# "play" sub-command: takes the words played so far and the state returned for
# each of them (both split on commas by play_mode) plus the saved model to use.
parser_play = subparsers.add_parser(
    'play', help='Give the model a word and the state result and the model will try to predict the goal word')
parser_play.add_argument(
    "--words", "-w", help="List of words played in the wordle game", required=True)
parser_play.add_argument(
    "--states", "-st", help="List of states returned by playing each of the words", required=True)
parser_play.add_argument(
    "--model_name", "-m", help="Name of the pretrained model file which will play the game", required=True)
# Dispatch to play_mode when the "play" sub-command is selected.
parser_play.set_defaults(func=play_mode)
86
+
87
  args = parser.parse_args()
88
  env_id = args.enviroment
89
  env = gym.make(env_id)