"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional

import torch
from torch import Tensor
from tqdm import tqdm
from transformers.modeling_outputs import CausalLMOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer


class Perplexity:
    """
    Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
    This is a custom variant that doesn't re-tokenize the input or re-load the model.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        max_seq_len: int,
        stride: int = 512,
    ) -> None:
        self.max_seq_len = max_seq_len
        self.stride = stride
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        self.name = "perplexity"

    def _feature_names(self) -> List[str]:
        return ["references"]

    def compute(
        self,
        references: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """
        Compute perplexity in a fixed length sliding window across the sequence.
        """
        assert references is not None, "Missing parameter: references"

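        # Tokenize every reference in one padded batch; the tokenizer must have a pad
        # token defined, and padded positions are not masked out of the labels below.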
        references_tokenized = self.tokenizer(
            references, return_tensors="pt", padding=True, truncation=True
        )
        input_ids: Tensor = references_tokenized["input_ids"]  # type: ignore
        input_ids = input_ids.to(self.device)

        sequence_length = input_ids.size(1)

        losses = []
        prev_end_loc = 0
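        # Slide a window of at most max_seq_len tokens across the sequence in steps of
        # `stride`. Overlapping windows (stride < max_seq_len) give the tokens scored in
        # each step additional left context carried over from the previous window.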
        for begin_loc in tqdm(range(0, sequence_length, self.stride)):
            end_loc = min(begin_loc + self.max_seq_len, sequence_length)
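            # Only score the tokens that were not already covered by the previous window.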
            trg_len = end_loc - prev_end_loc
            input_ids_slice = input_ids[:, begin_loc:end_loc]
            labels_slice = input_ids_slice.clone()
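            # Label positions set to -100 are ignored by the cross-entropy loss, so the
            # overlapping context tokens condition the model without being scored again.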
            labels_slice[:, :-trg_len] = -100

            with torch.no_grad():
                outputs: CausalLMOutput = self.model(
                    input_ids=input_ids_slice, labels=labels_slice
                )

            losses.append(outputs.loss)

            prev_end_loc = end_loc
            if end_loc == sequence_length:
                break

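        # Perplexity is exp() of the mean loss, where each window's loss is the model's
        # average cross-entropy over its unmasked label positions.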
        perplexity = torch.exp(torch.stack(losses).mean()).item()

        return {
            "score": perplexity,
        }
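

# Usage sketch, hedged: "gpt2" below is only an illustrative checkpoint, not something this
# module depends on; any causal LM paired with its own tokenizer should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    example_model = AutoModelForCausalLM.from_pretrained("gpt2")
    example_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so the padded batch tokenization works.
    example_tokenizer.pad_token = example_tokenizer.eos_token

    metric = Perplexity(example_model, example_tokenizer, max_seq_len=1024, stride=512)
    print(metric.compute(references=["The quick brown fox jumps over the lazy dog."]))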