atrost committed
Commit 52dc0ae · 1 Parent(s): f116fee

add local perplexity

Files changed (1)
  1. local_perplexity.py +151 -58
local_perplexity.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,85 +11,178 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""TODO: Add a description here."""
+"""Perplexity modified to use local models."""

-import evaluate
 import datasets
+import numpy as np
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import evaluate
+from evaluate import logging


-# TODO: Add BibTeX citation
 _CITATION = """\
-@InProceedings{huggingface:module,
-title = {A great new module},
-authors={huggingface, Inc.},
-year={2020}
-}
 """

-# TODO: Add description of the module here
-_DESCRIPTION = """\
-This new module is designed to solve this great ML task and is crafted with a lot of care.
+_DESCRIPTION = """
+Perplexity (PPL) is one of the most common metrics for evaluating language models.
+It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
+For more information, see https://huggingface.co/docs/transformers/perplexity
 """


-# TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates how good are predictions given some references, using certain scores
 Args:
-    predictions: list of predictions to score. Each predictions
-        should be a string with tokens separated by spaces.
-    references: list of reference for each prediction. Each
-        reference should be a string with tokens separated by spaces.
+    model_id (str): model used for calculating Perplexity
+            NOTE: Perplexity can only be calculated for causal language models.
+                    This includes models such as gpt2, causal variations of bert,
+                    causal versions of t5, and more (the full list can be found
+                    in the AutoModelForCausalLM documentation here:
+                    https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
+    predictions (list of str): input text, each separate text snippet
+        is one list entry.
+    batch_size (int): the batch size to run texts through the model. Defaults to 16.
+    add_start_token (bool): whether to add the start token to the texts,
+        so the perplexity can include the probability of the first word. Defaults to True.
+    device (str): device to run on, defaults to 'cuda' when available
 Returns:
-    accuracy: description of the first score,
-    another_score: description of the second score,
+    perplexity: dictionary containing the perplexity scores for the texts
+        in the input list, as well as the mean perplexity. If one of the input texts is
+        longer than the max input length of the model, then it is truncated to the
+        max length for the perplexity computation.
 Examples:
-    Examples should be written in doctest format, and should illustrate how
-    to use the function.
-
-    >>> my_new_module = evaluate.load("my_new_module")
-    >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
-    >>> print(results)
-    {'accuracy': 1.0}
+    Example 1:
+        >>> perplexity = evaluate.load("perplexity", module_type="metric")
+        >>> input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
+        >>> results = perplexity.compute(model_id='gpt2',
+        ...                              add_start_token=False,
+        ...                              predictions=input_texts) # doctest:+ELLIPSIS
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 0))
+        647.0
+        >>> print(round(results["perplexities"][0], 0))
+        32.0
+    Example 2:
+        >>> from datasets import load_dataset
+        >>> perplexity = evaluate.load("perplexity", module_type="metric")
+        >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10] # doctest: +SKIP
+        >>> input_texts = [s for s in input_texts if s!='']
+        >>> results = perplexity.compute(model_id='gpt2',
+        ...                              predictions=input_texts)
+        >>> print(list(results.keys()))
+        ['perplexities', 'mean_perplexity']
+        >>> print(round(results["mean_perplexity"], 2)) # doctest: +SKIP
+        576.76
+        >>> print(round(results["perplexities"][0], 2)) # doctest: +SKIP
+        889.28
 """

-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
-

 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
-class local_perplexity(evaluate.Metric):
-    """TODO: Short description of my evaluation module."""
-
+class Perplexity(evaluate.Metric):
     def _info(self):
-        # TODO: Specifies the evaluate.EvaluationModuleInfo object
         return evaluate.MetricInfo(
-            # This is the description that will appear on the modules page.
             module_type="metric",
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            # This defines the format of each prediction and reference
-            features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
-            }),
-            # Homepage of the module for documentation
-            homepage="http://module.homepage",
-            # Additional links to the codebase or references
-            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-            reference_urls=["http://path.to.reference.url/new_module"]
+            features=datasets.Features(
+                {
+                    "predictions": datasets.Value("string"),
+                }
+            ),
+            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
         )

-    def _download_and_prepare(self, dl_manager):
-        """Optional: download external resources useful to compute the scores"""
-        # TODO: Download external resources if needed
-        pass
-
-    def _compute(self, predictions, references):
-        """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+    def _compute(
+        self, predictions, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None, local_file_only: bool = False
+    ):
+
+        if device is not None:
+            assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
+            if device == "gpu":
+                device = "cuda"
+        else:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, local_files_only=local_file_only)
+        model = model.to(device)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=local_file_only)
+
+        # if batch_size > 1 (which generally leads to padding being required), and
+        # if there is not an already assigned pad_token, assign an existing
+        # special token to also be the padding token
+        if tokenizer.pad_token is None and batch_size > 1:
+            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
+            # check that the model already has at least one special token defined
+            assert (
+                len(existing_special_tokens) > 0
+            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
+            # assign one of the special tokens to also be the pad token
+            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
+
+        if add_start_token and max_length:
+            # leave room for <BOS> token to be added:
+            assert (
+                tokenizer.bos_token is not None
+            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
+            max_tokenized_len = max_length - 1
+        else:
+            max_tokenized_len = max_length
+
+        encodings = tokenizer(
+            predictions,
+            add_special_tokens=False,
+            padding=True,
+            truncation=True if max_tokenized_len else False,
+            max_length=max_tokenized_len,
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).to(device)
+
+        encoded_texts = encodings["input_ids"]
+        attn_masks = encodings["attention_mask"]
+
+        # check that each input is long enough:
+        if add_start_token:
+            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
+        else:
+            assert torch.all(
+                torch.ge(attn_masks.sum(1), 2)
+            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
+
+        ppls = []
+        loss_fct = CrossEntropyLoss(reduction="none")
+
+        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
+            end_index = min(start_index + batch_size, len(encoded_texts))
+            encoded_batch = encoded_texts[start_index:end_index]
+            attn_mask = attn_masks[start_index:end_index]
+
+            if add_start_token:
+                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
+                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
+                attn_mask = torch.cat(
+                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
+                )
+
+            labels = encoded_batch
+
+            with torch.no_grad():
+                out_logits = model(encoded_batch, attention_mask=attn_mask).logits
+
+            shift_logits = out_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
+
+            perplexity_batch = torch.exp(
+                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
+                / shift_attention_mask_batch.sum(1)
+            )
+
+            ppls += perplexity_batch.tolist()
+
+        return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}