import os

import torch
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

from wordle_env.state import update_from_mask

from .net import GreedyNet
from .utils import v_wrap

load_dotenv()

MODEL_NAME = os.getenv("RS_WORDLE_MODEL_NAME")
HF_MODEL_REPO_NAME = os.getenv("HF_MODEL_REPO_NAME")
MODEL_CHECKPOINT_DIR = "checkpoints"


def get_play_model_path():
    return os.path.join(MODEL_CHECKPOINT_DIR, MODEL_NAME)


def get_net(env, pretrained_model_path):
    # Infer the network dimensions from the environment.
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(env.words[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    # Fall back to downloading the checkpoint from the Hugging Face Hub
    # when no local copy exists.
    if not os.path.exists(pretrained_model_path):
        pretrained_model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_NAME, local_dir=MODEL_CHECKPOINT_DIR
        )
    net.load_state_dict(torch.load(pretrained_model_path))
    return net


def get_initial_state(env):
    state = env.reset()
    return state


def suggest(env, words, states, pretrained_model_path) -> str:
    """
    Given the guesses played so far and the feedback mask each one
    produced, return the next suggested word.

    :param env: Wordle environment (unwrapped internally)
    :param words: guesses made until now
    :param states: feedback mask for each guess, as a string of digits
    :param pretrained_model_path: local path to the model checkpoint
    :return: the word the network suggests playing next
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    # Replay the game history, folding each (guess, mask) pair into the state.
    for word, mask in zip(words, states):
        word = word.upper()
        mask = list(map(int, mask))
        state = update_from_mask(state, word, mask)
    return env.words[net.choose_action(v_wrap(state[None, :]))]
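

# Example usage (a sketch only): the mask strings below are assumed to
# follow wordle_env's update_from_mask convention of one feedback digit
# per letter; check that module for the actual encoding.
#
#     path = get_play_model_path()
#     next_word = suggest(env, ["crane"], ["01020"], path)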


def play(env, pretrained_model_path, goal_word=None):
    """Play a full game with the pretrained net; return (win, guesses)."""
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _ = env.step(action)
        outcomes.append(env.words[action])
        if done:
            # A positive terminal reward means the goal word was found.
            if reward > 0:
                win = True
            break
    return win, outcomes
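

if __name__ == "__main__":
    # Minimal smoke test, a sketch under assumptions: how the environment is
    # constructed depends on the project; the gym import and the
    # "WordleEnv100-v0" id are hypothetical stand-ins.
    import gym

    env = gym.make("WordleEnv100-v0")
    win, outcomes = play(env, get_play_model_path())
    print("win" if win else "loss", outcomes)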