import os

import torch
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from wordle_env.state import update_from_mask

from .net import GreedyNet
from .utils import v_wrap

load_dotenv()
MODEL_NAME = os.getenv("RS_WORDLE_MODEL_NAME")
HF_MODEL_REPO_NAME = os.getenv("HF_MODEL_REPO_NAME")
MODEL_CHECKPOINT_DIR = "checkpoints"


def get_play_model_path():
    return os.path.join(MODEL_CHECKPOINT_DIR, MODEL_NAME)


def get_net(env, pretrained_model_path):
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(env.words[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    # Fall back to downloading the checkpoint from the Hugging Face Hub
    # when it is not available locally.
    if not os.path.exists(pretrained_model_path):
        pretrained_model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_NAME, local_dir=MODEL_CHECKPOINT_DIR
        )
    net.load_state_dict(torch.load(pretrained_model_path))
    return net


def get_initial_state(env):
    return env.reset()


def suggest(env, words, states, pretrained_model_path) -> str:
    """
    Given the history of guessed words and their colour masks, return the
    next suggested word.

    :param env: Wordle environment
    :param words: Words guessed so far
    :param states: Colour mask for each guess, one digit per letter
    :param pretrained_model_path: Path to the pretrained model checkpoint
    :return: The next word to play
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    # Replay the game history to rebuild the current observation.
    for word, mask in zip(words, states):
        word = word.upper()
        mask = list(map(int, mask))
        state = update_from_mask(state, word, mask)
    return env.words[net.choose_action(v_wrap(state[None, :]))]


def play(env, pretrained_model_path, goal_word=None):
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _ = env.step(action)
        outcomes.append(env.words[action])
        if done:
            # A positive terminal reward means the goal word was guessed.
            win = reward > 0
            break
    return win, outcomes
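

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module's public API): a minimal example of
# driving `suggest` and `play`. It assumes `wordle_env` registers a Gym
# environment on import and that "WordleEnv100-v0" is a valid id; both the
# env id and the mask digit encoding below are assumptions, so adapt them
# to your setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gym
    import wordle_env  # noqa: F401  -- assumed to register the env on import

    env = gym.make("WordleEnv100-v0")  # hypothetical env id
    model_path = get_play_model_path()

    # Ask for the next guess given one previous guess and its colour mask
    # (digit-per-letter encoding assumed to match update_from_mask).
    print(suggest(env, ["raise"], ["01002"], model_path))

    # Let the agent play a full game against a fixed goal word; the word
    # must belong to the environment's vocabulary.
    win, guesses = play(env, model_path, goal_word="CRANE")
    print("won" if win else "lost", guesses)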