|
"""callback to calculate perplexity as an evaluation metric.""" |
|
from typing import Dict, List, Optional |
|
|
|
import torch |
|
from torch import Tensor |
|
from tqdm import tqdm |
|
from transformers.modeling_outputs import CausalLMOutput |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.tokenization_utils import PreTrainedTokenizer |
|
|
|
|
|
class Perplexity:
    """
    Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.

    This is a custom variant that doesn't re-tokenize the input or re-load the model.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        max_seq_len: int,
        stride: int = 512,
    ) -> None:
        """
        Args:
            model: causal LM, already placed on its target device.
            tokenizer: tokenizer matching ``model``; must have a pad token
                configured because ``compute`` batches with ``padding=True``.
            max_seq_len: window width (in tokens) fed to the model per forward pass.
            stride: step between window starts; a stride smaller than
                ``max_seq_len`` yields overlapping windows so earlier tokens
                provide context for the newly scored ones.
        """
        self.max_seq_len = max_seq_len
        self.stride = stride
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        self.name = "perplexity"

    def _feature_names(self) -> List[str]:
        # Names of the keyword arguments ``compute`` expects.
        return ["references"]

    def compute(
        self,
        references: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """
        Compute perplexity in a fixed length sliding window across the sequence.

        Args:
            references: texts to score; required despite the ``Optional``
                annotation (kept for interface compatibility).

        Returns:
            ``{"score": perplexity}`` where perplexity is
            ``exp(mean per-window loss)``.

        Raises:
            ValueError: if ``references`` is ``None``.
        """
        # ``raise`` instead of ``assert``: asserts are stripped under ``-O``.
        if references is None:
            raise ValueError("Missing parameter: references")

        references_tokenized = self.tokenizer(
            references, return_tensors="pt", padding=True, truncation=True
        )
        input_ids: Tensor = references_tokenized["input_ids"]
        input_ids = input_ids.to(self.device)
        # Padding mask must travel with the ids: without it, pad tokens are
        # attended to and scored, skewing perplexity for ragged batches.
        attention_mask: Tensor = references_tokenized["attention_mask"]
        attention_mask = attention_mask.to(self.device)

        sequence_length = input_ids.size(1)

        losses = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, sequence_length, self.stride)):
            end_loc = min(begin_loc + self.max_seq_len, sequence_length)
            # Only tokens not already scored by the previous window count
            # toward the loss; the overlap serves purely as context.
            trg_len = end_loc - prev_end_loc
            input_ids_slice = input_ids[:, begin_loc:end_loc]
            attention_mask_slice = attention_mask[:, begin_loc:end_loc]
            labels_slice = input_ids_slice.clone()
            # -100 is the ignore_index of the CE loss inside HF causal LMs.
            labels_slice[:, :-trg_len] = -100
            # Also exclude padding positions from the loss.
            labels_slice[attention_mask_slice == 0] = -100

            with torch.no_grad():
                outputs: CausalLMOutput = self.model(
                    input_ids=input_ids_slice,
                    attention_mask=attention_mask_slice,
                    labels=labels_slice,
                )

            losses.append(outputs.loss)

            prev_end_loc = end_loc
            if end_loc == sequence_length:
                break

        # NOTE(review): each window's mean loss is weighted equally here even
        # though the final window may score fewer tokens; this matches the
        # simple HF recipe but slightly biases the estimate for short tails.
        perplexity = torch.exp(torch.stack(losses).mean()).item()

        return {
            "score": perplexity,
        }
|
|