# NOTE(review): the original lines here ("Spaces:" / "Runtime error" x2) were
# web-page scrape residue, not Python source; preserved as a comment so the
# module parses.
import logging

import torch

# Project-local game state, tokenizer/model handles, and the game text.
from src.state import STATE, tokenizer, model
from src.text import get_text

logger = logging.getLogger(__name__)

# Token ids for the full game text, computed once at import time.
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the token ids of the model's top-3 next-token predictions.

    The original annotation claimed ``torch.Tensor``, but
    ``torch.topk(...).indices.tolist()`` yields a plain ``list[int]``;
    the annotation and docstring now match the actual return value.

    :param input_text: prompt text to feed the language model
    :return: list of the 3 highest-scoring next-token ids
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only — no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits at the last position score the *next* token.
    last_position_logits = logits[0, -1]
    return torch.topk(last_position_logits, 3).indices.tolist()
def handle_guess(text: str) -> tuple:
    """
    Process one player guess against the next token of the game text.

    Retrieves model predictions and compares the player's guess against
    the next token, accepting either the start-of-text or after-whitespace
    tokenization of the guess.

    :param text: the player's guessed word (may be empty)
    :return: 4-tuple of (current_text, player_guesses, lm_guesses,
        remaining_attempts) for the UI
    """
    current_tokens = all_tokens[:STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)
    player_guesses = ""
    lm_guesses = ""
    # FIXME(review): this counter is a local reset to 3 on every call, so
    # the `remaining_attempts == 0` branch below can never fire and the
    # decrement is lost between calls. It probably needs to live on STATE —
    # confirm against the caller before changing behavior.
    remaining_attempts = 3

    # Empty guess: return the current state unchanged.
    if not text:
        return (
            current_text,
            player_guesses,
            lm_guesses,
            remaining_attempts,
        )

    next_token = all_tokens[STATE.current_word_index]
    # A word tokenizes differently at the start of text vs. after a space,
    # so compute both candidate ids and accept either.
    predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]

    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))

    guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)

    if guess_is_correct or remaining_attempts == 0:
        # Advance to the next word and reset the per-word guess state.
        STATE.current_word_index += 1
        current_tokens = all_tokens[:STATE.current_word_index]
        remaining_attempts = 3
        STATE.player_guesses = []
        STATE.lm_guesses = []
    else:
        remaining_attempts -= 1
        STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))
        # FIXME: unoptimized, computing all three every time
        STATE.lm_guesses = get_model_predictions(tokenizer.decode(current_tokens))[:3 - remaining_attempts]

    # BUG FIX: the original decoded the *local* `lm_guesses` (always the
    # empty string, not a token-id sequence); decode the ids stored on STATE.
    logger.debug(f"lm_guesses: {tokenizer.decode(STATE.lm_guesses)}")

    # NOTE(review): the returned `lm_guesses` is always the empty string;
    # STATE.lm_guesses is updated but never surfaced here — verify intended.
    player_guesses = "\n".join(STATE.player_guesses)
    current_text = tokenizer.decode(current_tokens)
    return (
        current_text,
        player_guesses,
        lm_guesses,
        remaining_attempts,
    )