# Source: Hugging Face Space by marksverdhei — commit 614d543 "WIP: Add attempt count" (2.41 kB)
import torch
import logging
from src.state import STATE
from src.state import tokenizer
from src.state import model
from src.text import get_text
logger = logging.getLogger(__name__)
all_tokens = tokenizer.encode(get_text())
def get_model_predictions(input_text: str) -> list[int]:
    """
    Return the token ids of the model's top-3 next-token predictions.

    Note: the original annotation claimed ``torch.Tensor``, but
    ``torch.topk(...).indices.tolist()`` yields a plain ``list`` of ints.

    :param input_text: Text prefix to condition the language model on.
    :return: List of the 3 most likely next-token ids, most likely first.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Inference only — no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Logits at the last position = distribution over the *next* token.
    last_token = logits[0, -1]
    top_3 = torch.topk(last_token, 3).indices.tolist()
    return top_3
def handle_guess(text: str) -> tuple:
    """
    Process one player guess against the next hidden token.

    The guessed word is encoded two ways — at string start and after
    whitespace — because BPE-style tokenizers can assign different ids in the
    two positions; either id counts as a correct guess. On a correct guess, or
    once the player is out of attempts, the game advances one token and the
    per-token state (guesses, attempt counter) is reset.

    :param text: The player's guessed word; empty/falsy means "no guess yet".
    :return: 4-tuple of (current_text, player_guesses, lm_guesses,
             remaining_attempts) for the UI.
    """
    current_tokens = all_tokens[:STATE.current_word_index]
    current_text = tokenizer.decode(current_tokens)
    player_guesses = ""
    lm_guesses = ""
    # BUG FIX: remaining_attempts was a local always reset to 3, so the
    # counter never persisted between calls and the "out of attempts" branch
    # was unreachable. Persist it on STATE (default 3 if not yet set).
    remaining_attempts = getattr(STATE, "remaining_attempts", 3)

    if not text:
        return (
            current_text,
            player_guesses,
            lm_guesses,
            remaining_attempts
        )

    next_token = all_tokens[STATE.current_word_index]
    predicted_token_start = tokenizer.encode(text, add_special_tokens=False)[0]
    predicted_token_whitespace = tokenizer.encode(". " + text, add_special_tokens=False)[1]

    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([next_token])))
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))

    guess_is_correct = next_token in (predicted_token_start, predicted_token_whitespace)

    if guess_is_correct or remaining_attempts == 0:
        # Advance the game one token and reset per-token state.
        STATE.current_word_index += 1
        current_tokens = all_tokens[:STATE.current_word_index]
        remaining_attempts = 3
        STATE.player_guesses = []
        STATE.lm_guesses = []
    else:
        remaining_attempts -= 1
        STATE.player_guesses.append(tokenizer.decode([predicted_token_whitespace]))
        # FIXME: unoptimized, computing all three every time
        STATE.lm_guesses = get_model_predictions(tokenizer.decode(current_tokens))[:3 - remaining_attempts]

    # Persist the counter so the next call sees the decremented/reset value.
    STATE.remaining_attempts = remaining_attempts

    # BUG FIX: the local lm_guesses was never assigned from STATE.lm_guesses,
    # so the model's guesses were always returned as "" and the debug line
    # called tokenizer.decode("") on a str instead of token ids.
    lm_guesses = tokenizer.decode(STATE.lm_guesses)
    logger.debug(f"lm_guesses: {lm_guesses}")

    player_guesses = "\n".join(STATE.player_guesses)
    current_text = tokenizer.decode(current_tokens)

    return (
        current_text,
        player_guesses,
        lm_guesses,
        remaining_attempts
    )