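"""Play Wordle with a pretrained agent.

Helpers to load a GreedyNet checkpoint (locally or from the Hugging Face
Hub), suggest the next guess from a game history, and play full games.
"""
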
import os

import torch
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download

from wordle_env.state import update_from_mask

from .net import GreedyNet
from .utils import v_wrap

load_dotenv()

# The checkpoint file name and Hugging Face repo are configured via
# environment variables (see .env, loaded above).
MODEL_NAME = os.getenv("RS_WORDLE_MODEL_NAME")
HF_MODEL_REPO_NAME = os.getenv("HF_MODEL_REPO_NAME")
MODEL_CHECKPOINT_DIR = "checkpoints"


def get_play_model_path():
    return os.path.join(MODEL_CHECKPOINT_DIR, MODEL_NAME)


def get_net(env, pretrained_model_path):
    n_s = env.observation_space.shape[0]
    n_a = env.action_space.n
    words_list = env.words
    word_width = len(env.words[0])
    net = GreedyNet(n_s, n_a, words_list, word_width)
    if not os.path.exists(pretrained_model_path):
        # Checkpoint not on disk: fetch it from the Hugging Face Hub into
        # the local checkpoint directory.
        pretrained_model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_NAME, local_dir=MODEL_CHECKPOINT_DIR
        )
    net.load_state_dict(torch.load(pretrained_model_path))
    return net


def get_initial_state(env):
    return env.reset()


def suggest(env, words, states, pretrained_model_path) -> str:
    """
    Given the guesses played so far and their masks, return the next
    suggested word.

    :param env: Wordle environment (possibly wrapped)
    :param words: Guessed words, in play order
    :param states: Digit mask observed for each guess, e.g. "01020"
    :param pretrained_model_path: Path to the model checkpoint
    :return: The word the network suggests playing next
    """
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
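    # Replay the game history: fold each guess and its mask into the state.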
    for word, mask in zip(words, states):
        word = word.upper()
        mask = list(map(int, mask))
        state = update_from_mask(state, word, mask)
    return env.words[net.choose_action(v_wrap(state[None, :]))]


def play(env, pretrained_model_path, goal_word=None):
    env = env.unwrapped
    net = get_net(env, pretrained_model_path)
    state = get_initial_state(env)
    if goal_word:
        env.set_goal_word(goal_word)
    outcomes = []
    win = False
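    # Query the network for a guess each turn until the episode terminates.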
    for _ in range(env.max_turns):
        action = net.choose_action(v_wrap(state[None, :]))
        state, reward, done, _ = env.step(action)
        outcomes.append(env.words[action])
        if done:
            # The env signals a win with a positive terminal reward.
            if reward > 0:
                win = True
            break
    return win, outcomes
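

# Usage sketch (illustrative, not executed as part of this module). The env
# id "WordleEnv100-v0", the guesses, masks, and goal word below are
# assumptions for the example, not values defined here; RS_WORDLE_MODEL_NAME
# and HF_MODEL_REPO_NAME must be set in the environment or a .env file.
#
#     import gym
#
#     env = gym.make("WordleEnv100-v0")
#     model_path = get_play_model_path()
#
#     # Next guess after two rounds, given each round's digit mask:
#     word = suggest(env, ["crane", "sloth"], ["01000", "00020"], model_path)
#
#     # Full game against a fixed goal word:
#     win, guesses = play(env, model_path, goal_word="HOUSE")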