File size: 2,406 Bytes
614d543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import logging

from src.state import STATE
from src.state import tokenizer
from src.state import model
from src.text import get_text

logger = logging.getLogger(__name__)

# Tokenize the full game text once at import time; STATE.current_word_index
# is an index into this list and drives the game's progress.
all_tokens = tokenizer.encode(get_text())


def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the token ids of the top-3 next-token predictions.

    Runs the language model on ``input_text`` and ranks the logits of the
    final position only.

    :param input_text: text prefix to condition the model on.
    :return: list of the 3 highest-scoring token ids, best first.
    """
    inputs = tokenizer(input_text, return_tensors="pt")

    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Logits for the last position of the (single) batch element.
    last_token_logits = logits[0, -1]
    # BUG FIX: the annotation/docstring previously claimed a torch.Tensor,
    # but .tolist() yields plain ints — the annotation now matches.
    return torch.topk(last_token_logits, 3).indices.tolist()


def handle_guess(text: str) -> tuple:
    """
    Process a player's guess for the next token of the text.

    Retrieves model predictions and compares the guessed word — tokenized
    both standalone and after a leading space — against the actual next
    token. On a correct guess the game advances by one token and per-word
    state is reset.

    :param text: the player's guessed word; empty string means "no guess yet".
    :return: tuple of (current_text, player_guesses, lm_guesses,
             remaining_attempts).
    """
    current_tokens = all_tokens[:STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)
    player_guesses = ""
    lm_guesses = ""
    # NOTE(review): this resets to 3 on every call, so it can never reach 0
    # and the `remaining_attempts == 0` branch below is unreachable across
    # calls — it likely belongs on STATE. Left as-is pending confirmation
    # of the intended game rules.
    remaining_attempts = 3

    if not text:
        return (
            current_text,
            player_guesses,
            lm_guesses,
            remaining_attempts
        )

    next_token = all_tokens[STATE.current_word_index]
    # The same word tokenizes differently at string start vs. after
    # whitespace; compute both forms and accept either as a match.
    predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))

    guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)

    if guess_is_correct or remaining_attempts == 0:
        # Advance the game one token and reset per-word guess state.
        STATE.current_word_index += 1
        current_tokens = all_tokens[:STATE.current_word_index]
        remaining_attempts = 3
        STATE.player_guesses = []
        STATE.lm_guesses = []
    else:
        remaining_attempts -= 1
        STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))

    # FIXME: unoptimized, computing all three every time
    STATE.lm_guesses = get_model_predictions(tokenizer.decode(current_tokens))[:3-remaining_attempts]
    # BUG FIX: previously the local `lm_guesses` was never assigned, so the
    # function always logged and returned an empty string instead of the
    # freshly computed model guesses.
    lm_guesses = tokenizer.decode(STATE.lm_guesses)
    logger.debug(f"lm_guesses: {lm_guesses}")

    player_guesses = "\n".join(STATE.player_guesses)
    current_text = tokenizer.decode(current_tokens)

    return (
        current_text,
        player_guesses,
        lm_guesses,
        remaining_attempts
    )