import logging

from transformers import PreTrainedTokenizer

from src import shared
from src.constants import MAX_ATTEMPTS
from src.constants import STARTING_INDEX
from src.params import ReducerParams
from src.shared import token_id_predictions

logger = logging.getLogger(__name__)


def get_current_lm_guess_str(word_number, remaining_attempts):
    """Build the newline-joined guess list for the current word.

    Guesses for attempts already used are decoded to text; slots for the
    remaining attempts stay censored as "*****".

    :param word_number: index of the word currently being guessed
    :param remaining_attempts: attempts the player still has left
    :return: MAX_ATTEMPTS lines joined with "\\n"
    """
    censored_list = ["*****"] * MAX_ATTEMPTS
    try:
        guess_ids = token_id_predictions[(STARTING_INDEX + word_number) - 1][1]
    except IndexError:
        # word_number is past the precomputed predictions; show a fully
        # censored list instead of crashing (was "FIXME: indexerror").
        logger.warning("word_number %s is out of range of predictions", word_number)
        return "\n".join(censored_list)
    guess_list = [shared.tokenizer.decode(i) for i in guess_ids]
    # Clamp so neither fewer predictions than used attempts nor a negative
    # remaining_attempts can over-run either list.
    used = MAX_ATTEMPTS - remaining_attempts
    for i in range(min(used, len(guess_list), MAX_ATTEMPTS)):
        censored_list[i] = guess_list[i]
    return "\n".join(censored_list)


def get_current_prompt_text(word_number):
    """Decode the prompt text up to (but not including) the current word.

    Slicing never raises IndexError, so no guard is needed here; an
    out-of-range word_number simply yields the full token list.
    """
    return shared.tokenizer.decode(shared.all_tokens[: STARTING_INDEX + word_number])


def get_start_and_whitespace_tokens(
    word: str,
    tokenizer: PreTrainedTokenizer,
) -> tuple[int, int]:
    """Return both token ids a guessed word may correspond to.

    It is difficult to tell whether the target token carries a leading
    whitespace, so we compute two candidates: the id of *word* encoded on
    its own, and the id it receives when preceded by other text (which
    makes BPE-style tokenizers emit the whitespace-prefixed variant).

    :param word: the player's guessed word
    :param tokenizer: tokenizer used for the underlying LM
    :return: (start-of-text token id, whitespace-prefixed token id)
    """
    predicted_token_start = tokenizer.encode(word, add_special_tokens=False)[0]
    # ". \n" puts *word* mid-text so position 1 holds the prefixed variant.
    predicted_token_whitespace = tokenizer.encode(". \n" + word, add_special_tokens=False)[1]
    return predicted_token_start, predicted_token_whitespace


def lm_is_correct(params: ReducerParams) -> bool:
    """Whether the LM's guess for the current attempt equals the target token."""
    # NOTE: idx would be out of range if remaining_attempts is 0
    if params.remaining_attempts < 1:
        return False
    idx = MAX_ATTEMPTS - params.remaining_attempts
    try:
        # entry is (target_token_id, [guess_token_ids, ...])
        entry = token_id_predictions[STARTING_INDEX + params.word_number - 1]
        current_guess = entry[1][idx]
    except IndexError:
        # Past the end of the precomputed predictions: treat as incorrect
        # rather than crashing (was "FIXME: indexerror").
        logger.warning("no prediction for word_number %s / idx %s", params.word_number, idx)
        return False
    current_target = entry[0]
    return current_guess == current_target


def guess_is_correct(params: ReducerParams, tokenizer: PreTrainedTokenizer) -> bool:
    """
    We check if the predicted token or a corresponding one with
    a leading whitespace matches that of the next token
    """
    token_index = STARTING_INDEX + params.word_number
    # Was a bare debug print(); route through the module logger instead.
    logger.debug("checking guess at token index %s", token_index)
    try:
        current_target = shared.all_tokens[token_index]
    except IndexError:
        # Past the end of the text: nothing left to guess (was
        # "FIXME: handle indexerro[r]").
        logger.warning("token index %s is out of range of all_tokens", token_index)
        return False
    logger.debug("Next token: '{}'".format(tokenizer.convert_ids_to_tokens([current_target])))
    predicted_token_start, predicted_token_whitespace = get_start_and_whitespace_tokens(
        params.guess_field, tokenizer
    )
    logger.debug(tokenizer.convert_ids_to_tokens([predicted_token_start, predicted_token_whitespace]))
    return current_target in (predicted_token_start, predicted_token_whitespace)