yu-val-weiss committed
Commit 8f3cd77 · 1 Parent(s): 803da62

Update blimp.py

Files changed (1)
  1. blimp.py +166 -79
blimp.py CHANGED
@@ -15,13 +15,83 @@

import datasets
import evaluate
- import numpy as np
import torch
from evaluate import logging
- from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

- _CITATION = """\
@article{warstadt2020blimp,
    author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
    title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},
@@ -37,8 +107,7 @@ _CITATION = """\
}
"""

- _DESCRIPTION = """
- BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
BLiMP consists of 67 sub-datasets, each containing 1000 minimal pairs isolating specific contrasts in syntax, morphology, or semantics.
The data is automatically generated according to expert-crafted grammars. Aggregate human agreement with the labels is 96.4%.
We use BLiMP to evaluate an n-gram LM, LSTM LM, GPT-2, and Transformer-XL.
@@ -48,9 +117,12 @@ For more info see https://github.com/alexwarstadt/blimp.

_KWARGS_DESCRIPTION = """
Args:
-     model_id (str): model used for calculating Blimp
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
-     device (str): device to run on, defaults to 'cuda' when available
Returns:
    blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
    An LM’s overall accuracy on BLiMP is simply the proportion of the 67,000 minimal pairs in which the model assigns a higher probability to the acceptable sentence.
@@ -60,7 +132,7 @@ Examples:


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
- class Perplexity(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
@@ -80,12 +152,11 @@ class Perplexity(evaluate.Metric):

    def _compute(
        self,
-         predictions,
        model_id,
        batch_size: int = 16,
-         add_start_token: bool = True,
        device=None,
-         max_length=None,
    ):
        if device is not None:
            assert device in ["gpu", "cpu", "cuda", "mps"], (
@@ -102,6 +173,7 @@ class Perplexity(evaluate.Metric):

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -119,78 +191,93 @@
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

-         if add_start_token and max_length:
-             # leave room for <BOS> token to be added:
-             assert tokenizer.bos_token is not None, (
-                 "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
-             )
-             max_tokenized_len = max_length - 1
-         else:
-             max_tokenized_len = max_length
-
-         encodings = tokenizer(
-             predictions,
-             add_special_tokens=False,
-             padding=True,
-             truncation=True if max_tokenized_len else False,
-             max_length=max_tokenized_len,
-             return_tensors="pt",
-             return_attention_mask=True,
-         ).to(device)

-         encoded_texts = encodings["input_ids"]
-         attn_masks = encodings["attention_mask"]

-         # check that each input is long enough:
-         if add_start_token:
-             assert torch.all(torch.ge(attn_masks.sum(1), 1)), (
-                 "Each input text must be at least one token long."
            )
-         else:
-             assert torch.all(torch.ge(attn_masks.sum(1), 2)), (
-                 "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."
            )

-         ppls = []
-         loss_fct = CrossEntropyLoss(reduction="none")
-
-         for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
-             end_index = min(start_index + batch_size, len(encoded_texts))
-             encoded_batch = encoded_texts[start_index:end_index]
-             attn_mask = attn_masks[start_index:end_index]
-
-             if add_start_token:
-                 bos_tokens_tensor = torch.tensor(
-                     [[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)
-                 ).to(device)
-                 encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
-                 attn_mask = torch.cat(
-                     [
-                         torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(
-                             device
-                         ),
-                         attn_mask,
-                     ],
-                     dim=1,
-                 )
-
-             labels = encoded_batch
-
-             with torch.no_grad():
-                 out_logits = model(encoded_batch, attention_mask=attn_mask).logits
-
-             shift_logits = out_logits[..., :-1, :].contiguous()
-             shift_labels = labels[..., 1:].contiguous()
-             shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
-
-             perplexity_batch = torch.exp(
-                 (
-                     loss_fct(shift_logits.transpose(1, 2), shift_labels)
-                     * shift_attention_mask_batch
-                 ).sum(1)
-                 / shift_attention_mask_batch.sum(1)
-             )

-             ppls += perplexity_batch.tolist()

-         return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}


import datasets
import evaluate
import torch
from evaluate import logging
from transformers import AutoModelForCausalLM, AutoTokenizer

+ datasets.logging.set_verbosity_error()
+
+ BLIMP_PHENOMENA = [
+     "adjunct_island",
+     "anaphor_gender_agreement",
+     "anaphor_number_agreement",
+     "animate_subject_passive",
+     "animate_subject_trans",
+     "causative",
+     "complex_NP_island",
+     "coordinate_structure_constraint_complex_left_branch",
+     "coordinate_structure_constraint_object_extraction",
+     "determiner_noun_agreement_1",
+     "determiner_noun_agreement_2",
+     "determiner_noun_agreement_irregular_1",
+     "determiner_noun_agreement_irregular_2",
+     "determiner_noun_agreement_with_adj_2",
+     "determiner_noun_agreement_with_adj_irregular_1",
+     "determiner_noun_agreement_with_adj_irregular_2",
+     "determiner_noun_agreement_with_adjective_1",
+     "distractor_agreement_relational_noun",
+     "distractor_agreement_relative_clause",
+     "drop_argument",
+     "ellipsis_n_bar_1",
+     "ellipsis_n_bar_2",
+     "existential_there_object_raising",
+     "existential_there_quantifiers_1",
+     "existential_there_quantifiers_2",
+     "existential_there_subject_raising",
+     "expletive_it_object_raising",
+     "inchoative",
+     "intransitive",
+     "irregular_past_participle_adjectives",
+     "irregular_past_participle_verbs",
+     "irregular_plural_subject_verb_agreement_1",
+     "irregular_plural_subject_verb_agreement_2",
+     "left_branch_island_echo_question",
+     "left_branch_island_simple_question",
+     "matrix_question_npi_licensor_present",
+     "npi_present_1",
+     "npi_present_2",
+     "only_npi_licensor_present",
+     "only_npi_scope",
+     "passive_1",
+     "passive_2",
+     "principle_A_c_command",
+     "principle_A_case_1",
+     "principle_A_case_2",
+     "principle_A_domain_1",
+     "principle_A_domain_2",
+     "principle_A_domain_3",
+     "principle_A_reconstruction",
+     "regular_plural_subject_verb_agreement_1",
+     "regular_plural_subject_verb_agreement_2",
+     "sentential_negation_npi_licensor_present",
+     "sentential_negation_npi_scope",
+     "sentential_subject_island",
+     "superlative_quantifiers_1",
+     "superlative_quantifiers_2",
+     "tough_vs_raising_1",
+     "tough_vs_raising_2",
+     "transitive",
+     "wh_island",
+     "wh_questions_object_gap",
+     "wh_questions_subject_gap",
+     "wh_questions_subject_gap_long_distance",
+     "wh_vs_that_no_gap",
+     "wh_vs_that_no_gap_long_distance",
+     "wh_vs_that_with_gap",
+     "wh_vs_that_with_gap_long_distance",
+ ]
+
+ _CITATION = r"""
@article{warstadt2020blimp,
    author = {Warstadt, Alex and Parrish, Alicia and Liu, Haokun and Mohananey, Anhad and Peng, Wei and Wang, Sheng-Fu and Bowman, Samuel R.},
    title = {BLiMP: The Benchmark of Linguistic Minimal Pairs for English},

}
"""

+ _DESCRIPTION = """BLiMP is a challenge set for evaluating what language models (LMs) know about major grammatical phenomena in English.
BLiMP consists of 67 sub-datasets, each containing 1000 minimal pairs isolating specific contrasts in syntax, morphology, or semantics.
The data is automatically generated according to expert-crafted grammars. Aggregate human agreement with the labels is 96.4%.
We use BLiMP to evaluate an n-gram LM, LSTM LM, GPT-2, and Transformer-XL.

_KWARGS_DESCRIPTION = """
Args:
+     model_id (str): model used for calculating Blimp, NOTE: should be a causal LM model
+     predictions (list[str]): names of metrics to run. pass empty list or ["*"] to run all of them
    batch_size (int): the batch size to run texts through the model. Defaults to 16.
+     device (str): device to run on, defaults to 'cuda' when available.
+     samples_per_set (int): the number of samples per phenomenon, defaults to 1_000.
+
Returns:
    blimp: dictionary containing the blimp scores for each of the 67 sub-datasets, as well as the overall accuracy.
    An LM’s overall accuracy on BLiMP is simply the proportion of the 67,000 minimal pairs in which the model assigns a higher probability to the acceptable sentence.
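
For orientation, here is a usage sketch of the updated metric; it is not part of the commit. The load path and model id are assumptions (substitute the local path or Hub id under which this blimp.py module is actually published), and it assumes compute() accepts predictions as a keyword, as the perplexity metric this module is adapted from does.

    import evaluate

    # Assumed load path; point evaluate at this module's real path or Hub id.
    blimp = evaluate.load("blimp")

    results = blimp.compute(
        model_id="gpt2",  # any causal LM checkpoint
        predictions=["anaphor_gender_agreement", "wh_island"],  # [] or ["*"] runs all 67 phenomena
        batch_size=16,
        samples_per_set=100,
    )
    print(results["overall_accuracy"])
    print(results["phenomenon_accuracies"])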
 

@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class Blimp(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
 
    def _compute(
        self,
        model_id,
+         predictions=None,
        batch_size: int = 16,
        device=None,
+         samples_per_set: int = 1_000,
    ):
        if device is not None:
            assert device in ["gpu", "cpu", "cuda", "mps"], (

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)
+         model.eval()

        tokenizer = AutoTokenizer.from_pretrained(model_id)
 
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

+         print("PAD", tokenizer.pad_token_id)
+
+         run_all = len(predictions) == 0 or predictions[0] == "*"
+         blimp_sets = (
+             BLIMP_PHENOMENA
+             if run_all
+             else [p for p in BLIMP_PHENOMENA if p.lower() in predictions]
+         )

+         assert len(blimp_sets) > 0, "no valid phenomena selected"

+         results = {}
+
+         for phenomenon in logging.tqdm(blimp_sets, desc="Evaluating phenomena..."):
+             dataset = datasets.load_dataset("nyu-mll/blimp", phenomenon)["train"]
+
+             # Prepare batches of good and bad sentences
+
+             sents = [(x["sentence_good"], x["sentence_bad"]) for x in dataset]
+             good_sents, bad_sents = zip(*sents[: min(1000, samples_per_set)])
+
+             # Get probabilities in batches
+             good_probs = get_batch_probabilities(
+                 model, tokenizer, good_sents, device, batch_size, phenomenon
            )
+             bad_probs = get_batch_probabilities(
+                 model,
+                 tokenizer,
+                 bad_sents,
+                 device,
+                 batch_size,
+                 phenomenon,
+                 sent_type="bad",
            )

+             # Compare probabilities
+             correct = sum(g > b for g, b in zip(good_probs, bad_probs))
+             accuracy = correct / len(good_probs)
+             results[phenomenon] = accuracy
+
+         # Calculate overall accuracy
+         overall_accuracy = sum(results.values()) / len(results)
+
+         return {"phenomenon_accuracies": results, "overall_accuracy": overall_accuracy}
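
As a quick check of the data this loop consumes, the following standalone snippet (again, not part of the commit) loads one sub-dataset and prints the two fields the metric relies on; the config name is simply one entry from BLIMP_PHENOMENA.

    import datasets

    ds = datasets.load_dataset("nyu-mll/blimp", "adjunct_island")["train"]
    print(len(ds))                  # 1000 minimal pairs per phenomenon
    print(ds[0]["sentence_good"])   # acceptable sentence of the first pair
    print(ds[0]["sentence_bad"])    # its minimally different, unacceptable counterpart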
+
+
+ def get_batch_probabilities(
+     model,
+     tokenizer,
+     sentences: list[str],
+     device: str,
+     batch_size: int,
+     phenomenon: str,
+     sent_type: str = "good",
+ ):
+     """Compute log probabilities for a batch of sentences"""
+     probs = []
+
+     for i in logging.tqdm(
+         range(0, len(sentences), batch_size),
+         desc=f"{phenomenon} - {sent_type} sentences...",
+         leave=False,
+     ):
+         batch = sentences[i : i + batch_size]
+         inputs = tokenizer(
+             batch, padding=batch_size > 1, return_tensors="pt", truncation=True
+         ).to(device)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         labels = inputs.input_ids
+
+         # Compute log probabilities
+         log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)
+
+         # Get probability of each actual token
+         token_log_probs = torch.gather(log_probs, 2, labels.unsqueeze(-1)).squeeze(-1)
+
+         if batch_size > 1:
+             # Create attention mask for padding
+             mask = (labels != tokenizer.pad_token_id).float()
+             token_log_probs *= mask
+
+         # sum log probabilities
+         sequence_log_probs = (token_log_probs).sum(dim=1)

+         probs.extend(sequence_log_probs.cpu().tolist())

+     return probs
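
For reference, a minimal standalone sketch of scoring one minimal pair with a causal LM. It is not part of the commit, the helper and variable names are illustrative, and the example pair is an agreement contrast in the style of BLiMP's minimal pairs. Unlike get_batch_probabilities above, it shifts logits against labels (the logits at position t score the token at position t + 1), the same factorization the removed perplexity code used.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer


    def sentence_log_prob(model, tokenizer, sentence: str) -> float:
        """Illustrative helper: sum of log P(token_t | tokens_<t) over the sentence."""
        enc = tokenizer(sentence, return_tensors="pt")
        input_ids = enc.input_ids
        with torch.no_grad():
            logits = model(**enc).logits
        # Logits at position t predict the token at position t + 1, so shift before gathering.
        log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
        targets = input_ids[:, 1:]
        token_log_probs = log_probs.gather(2, targets.unsqueeze(-1)).squeeze(-1)
        return token_log_probs.sum().item()


    model_id = "gpt2"  # assumption: any causal LM checkpoint works here
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id).eval()

    good = "These casseroles disgust Kayla."
    bad = "These casseroles disgusts Kayla."
    # The acceptable sentence should receive the higher total log probability.
    print(sentence_log_prob(model, tokenizer, good) > sentence_log_prob(model, tokenizer, bad))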