Add play mode
Given a word, a state, and a saved model, the model returns the probable goal word.
- a3c/eval.py +1 -15
- a3c/play.py +48 -0
- main.py +21 -1
a3c/eval.py
CHANGED

@@ -2,6 +2,7 @@ import os
 import torch

 from .net import GreedyNet
+from .play import play
 from .utils import v_wrap


@@ -38,18 +39,3 @@ def evaluate(net, env):
     print(f"Evaluation complete, won {n_wins/N*100}% and took {n_win_guesses/n_wins} guesses per win, "
           f"{n_guesses / N} including losses.")
     return n_wins/N*100, n_win_guesses/n_wins
-
-
-def play(net, env):
-    state = env.reset()
-    outcomes = []
-    win = False
-    for i in range(env.max_turns):
-        action = net.choose_action(v_wrap(state[None, :]))
-        state, reward, done, _ = env.step(action)
-        outcomes.append((env.words[action], reward))
-        if done:
-            if reward >= 0:
-                win = True
-            break
-    return win, outcomes
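
Note: evaluate() keeps calling the same play() helper; only its home module changes, since play() now lives in a3c/play.py and is imported back here. A minimal sketch of driving the relocated helper on its own, assuming a gym-registered Wordle environment id and a checkpoint path that are not part of this commit:

# Hypothetical driver for the relocated play() helper.
# The environment id and checkpoint filename are assumptions, not taken from this diff.
import gym
import torch

from a3c.net import GreedyNet
from a3c.play import play

env = gym.make("WordleEnv100-v0")        # assumed environment id
n_s = env.observation_space.shape[0]
n_a = env.action_space.n
env = env.unwrapped                      # same pattern suggest() uses below
net = GreedyNet(n_s, n_a, env.words, len(env.words[0]))
net.load_state_dict(torch.load("checkpoints/model.pth"))  # assumed path

won, outcomes = play(net, env)           # win flag plus list of (guess, reward) pairs
print("won" if won else "lost", outcomes)
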
a3c/play.py
ADDED

@@ -0,0 +1,48 @@
+import torch
+
+from .net import GreedyNet
+from .utils import v_wrap
+from wordle_env.state import update_from_mask
+
+
+def suggest(
+    env,
+    words,
+    states,
+    pretrained_model_path
+) -> str:
+    """
+    Given a list of words and masks, return the next suggested word
+
+    :param env: Wordle gym environment
+    :param words: Words played so far
+    :param states: Mask returned for each played word
+    :return: The suggested next word
+    """
+    n_s = env.observation_space.shape[0]
+    n_a = env.action_space.n
+    env = env.unwrapped
+    state = env.reset()
+    words_list = env.words
+    word_width = len(env.words[0])
+    net = GreedyNet(n_s, n_a, words_list, word_width)
+    net.load_state_dict(torch.load(pretrained_model_path))
+    for word, mask in zip(words, states):
+        word = word.upper()
+        mask = list(map(int, mask))
+        state = update_from_mask(state, word, mask)
+    return env.words[net.choose_action(v_wrap(state[None, :]))]
+
+
+def play(net, env):
+    state = env.reset()
+    outcomes = []
+    win = False
+    for i in range(env.max_turns):
+        action = net.choose_action(v_wrap(state[None, :]))
+        state, reward, done, _ = env.step(action)
+        outcomes.append((env.words[action], reward))
+        if done:
+            if reward >= 0:
+                win = True
+            break
+    return win, outcomes
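
A quick sketch of calling suggest() directly from Python. The environment id, the checkpoint path, and the exact digit encoding of each mask string (one digit per letter, as implied by list(map(int, mask)) and update_from_mask) are assumptions, not something this diff pins down:

# Hypothetical call to suggest(); env id, checkpoint path and mask digits are assumptions.
import gym

from a3c.play import suggest

env = gym.make("WordleEnv100-v0")   # assumed environment id
words = ["arise"]                   # guesses already played
states = ["01020"]                  # one digit per letter of the corresponding guess
print(suggest(env, words, states, "checkpoints/model.pth"))
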
main.py
CHANGED

@@ -8,6 +8,7 @@ import time
 import matplotlib.pyplot as plt
 from a3c.train import train
 from a3c.eval import evaluate, evaluate_checkpoints
+from a3c.play import suggest
 from wordle_env.wordle import WordleEnvBase


@@ -27,6 +28,15 @@ def evaluation_mode(args, env, model_checkpoint_dir):
     print(results)


+def play_mode(args, env, model_checkpoint_dir):
+    print("Play mode")
+    words = [word.strip() for word in args.words.split(',')]
+    states = [state.strip() for state in args.states.split(',')]
+    pretrained_model_path = os.path.join(model_checkpoint_dir, args.model_name)
+    word = suggest(env, words, states, pretrained_model_path)
+    print(word)
+
+
 def print_results(global_ep, win_ep, res):
     print("Jugadas:", global_ep.value)
     print("Ganadas:", win_ep.value)
@@ -49,7 +59,7 @@ if __name__ == "__main__":
     parser_train.add_argument(
         "--games", "-g", help="Number of games to train", type=int, required=True)
     parser_train.add_argument(
-        "--model_name", "-
+        "--model_name", "-m", help="If you want to train from a pretrained model, the name of the pretrained model file")
     parser_train.add_argument(
         "--gamma", help="Gamma hyperparameter (discount factor) value", type=float, default=0.)
     parser_train.add_argument(
@@ -64,6 +74,16 @@ if __name__ == "__main__":
         'eval', help='Evaluate saved models for the enviroment')
     parser_eval.set_defaults(func=evaluation_mode)

+    parser_play = subparsers.add_parser(
+        'play', help='Give the model a word and the state result and the model will try to predict the goal word')
+    parser_play.add_argument(
+        "--words", "-w", help="List of words played in the wordle game", required=True)
+    parser_play.add_argument(
+        "--states", "-st", help="List of states returned by playing each of the words", required=True)
+    parser_play.add_argument(
+        "--model_name", "-m", help="Name of the pretrained model file which will play the game", required=True)
+    parser_play.set_defaults(func=play_mode)
+
     args = parser.parse_args()
     env_id = args.enviroment
     env = gym.make(env_id)
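
With the subparser wired up, play mode would be invoked along these lines; the environment id, words, mask digits and model filename are placeholders, and the placement of the positional enviroment argument relative to the play subcommand is inferred from args.enviroment rather than shown in this diff:

python main.py WordleEnv100-v0 play --words arise,pilot --states 01020,00000 --model_name model.pth

play_mode() then joins the model path, forwards everything to suggest(), and prints the suggested goal word.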