"""Callback to calculate perplexity as an evaluation metric."""

from typing import Dict, List, Optional

import torch
from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

# Label value skipped by the loss (the Hugging Face / PyTorch cross-entropy
# ``ignore_index`` convention).
_IGNORE_INDEX = -100


class Perplexity:
    """
    Calculate perplexity as defined in
    https://huggingface.co/docs/transformers/en/perplexity.

    This is a custom variant that works with the model and tokenizer instances
    it is given instead of loading them again.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        max_seq_len: int,
        stride: int = 512,
    ) -> None:
        """
        Store the evaluation settings.

        Args:
            model: Model used to score the references; its ``device`` is where
                the token ids are moved before scoring.
            tokenizer: Tokenizer applied to the references in ``compute``.
            max_seq_len: Maximum number of tokens in one scoring window.
            stride: Step, in tokens, between the starts of two windows.
        """
        self.max_seq_len = max_seq_len
        self.stride = stride
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        self.name = "perplexity"

    def _feature_names(self) -> List[str]:
        """Return the names of the inputs that ``compute`` expects."""
        return ["references"]

    def compute(
        self,
        references: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """
        Compute perplexity in a fixed length sliding window across the sequence.

        Every window is weighted by the number of tokens it actually scores,
        and padded positions are excluded, so the result is the exponential of
        the mean negative log-likelihood per scored token.

        Args:
            references: Texts to score. Required despite the ``None`` default.

        Returns:
            ``{"score": perplexity}``. The score is NaN when no token could be
            scored (for example when every reference is a single token).

        Raises:
            AssertionError: If ``references`` is ``None``.
        """
        assert references is not None, "Missing parameter: references"

        encoded = self.tokenizer(
            references, return_tensors="pt", padding=True, truncation=True
        )
        input_ids: Tensor = encoded["input_ids"]  # type: ignore
        input_ids = input_ids.to(self.device)

        # ``padding=True`` pads every reference to the longest one; the mask
        # marks which positions are real tokens so padding is not scored.
        attention_mask: Optional[Tensor] = None
        if "attention_mask" in encoded:
            attention_mask = encoded["attention_mask"].to(self.device)

        sequence_length = input_ids.size(1)

        weighted_losses: List[Tensor] = []
        total_scored_tokens = 0
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, sequence_length, self.stride)):
            end_loc = min(begin_loc + self.max_seq_len, sequence_length)
            # Only tokens not already scored by a previous window are targets;
            # on the last window this may differ from the stride.
            trg_len = end_loc - prev_end_loc

            window_ids = input_ids[:, begin_loc:end_loc]
            labels = window_ids.clone()
            labels[:, :-trg_len] = _IGNORE_INDEX

            model_kwargs = {"input_ids": window_ids, "labels": labels}
            if attention_mask is not None:
                window_mask = attention_mask[:, begin_loc:end_loc]
                labels.masked_fill_(window_mask == 0, _IGNORE_INDEX)
                model_kwargs["attention_mask"] = window_mask

            # Assumes the usual causal-LM loss, which shifts the labels left by
            # one internally: the first position of a window is never a target.
            scored_tokens = int((labels[:, 1:] != _IGNORE_INDEX).sum().item())

            # A window without any target would yield a NaN mean loss and
            # poison the final score, so it is skipped.
            if scored_tokens > 0:
                with torch.no_grad():
                    outputs: CausalLMOutput = self.model(**model_kwargs)
                # The reported loss is a mean over the scored tokens; scale it
                # back to a sum so windows of different size weigh correctly.
                weighted_losses.append(outputs.loss.float() * scored_tokens)
                total_scored_tokens += scored_tokens

            prev_end_loc = end_loc
            if end_loc == sequence_length:
                break

        if total_scored_tokens == 0:
            # Perplexity is undefined without at least one scored token.
            return {
                "score": float("nan"),
            }

        mean_nll = torch.stack(weighted_losses).sum() / total_scored_tokens
        perplexity = torch.exp(mean_nll).item()

        return {
            "score": perplexity,
        }