yu-val-weiss committed
Commit 04c0ccd · 1 Parent(s): 8f3cd77

add by phenomenon

Files changed (1)
  1. blimp.py +30 -11
blimp.py CHANGED
@@ -13,6 +13,8 @@
 # limitations under the License.
 """Blimp Metric."""
 
+from collections import defaultdict
+
 import datasets
 import evaluate
 import torch
@@ -21,7 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 datasets.logging.set_verbosity_error()
 
-BLIMP_PHENOMENA = [
+BLIMP_UIDS = [
     "adjunct_island",
     "anaphor_gender_agreement",
     "anaphor_number_agreement",
@@ -191,30 +193,37 @@ class Blimp(evaluate.Metric):
             # assign one of the special tokens to also be the pad token
             tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})
 
-        print("PAD", tokenizer.pad_token_id)
-
         run_all = len(predictions) == 0 or predictions[0] == "*"
         blimp_sets = (
-            BLIMP_PHENOMENA
+            BLIMP_UIDS
             if run_all
-            else [p for p in BLIMP_PHENOMENA if p.lower() in predictions]
+            else [p for p in BLIMP_UIDS if p.lower() in predictions]
        )
 
        assert len(blimp_sets) > 0, "no valid phenomena selected"
 
        results = {}
+        phenom_results = defaultdict(list)
 
-        for phenomenon in logging.tqdm(blimp_sets, desc="Evaluating phenomena..."):
-            dataset = datasets.load_dataset("nyu-mll/blimp", phenomenon)["train"]
+        for category in logging.tqdm(blimp_sets, desc="Evaluating phenomena..."):
+            dataset = datasets.load_dataset("nyu-mll/blimp", category)["train"]
 
             # Prepare batches of good and bad sentences
 
+            phenom = dataset[0]["linguistics_term"]
+
             sents = [(x["sentence_good"], x["sentence_bad"]) for x in dataset]
             good_sents, bad_sents = zip(*sents[: min(1000, samples_per_set)])
 
             # Get probabilities in batches
             good_probs = get_batch_probabilities(
-                model, tokenizer, good_sents, device, batch_size, phenomenon
+                model,
+                tokenizer,
+                good_sents,
+                device,
+                batch_size,
+                category,
+                sent_type="good",
             )
             bad_probs = get_batch_probabilities(
                 model,
@@ -222,19 +231,29 @@ class Blimp(evaluate.Metric):
                 bad_sents,
                 device,
                 batch_size,
-                phenomenon,
+                category,
                 sent_type="bad",
             )
 
             # Compare probabilities
             correct = sum(g > b for g, b in zip(good_probs, bad_probs))
             accuracy = correct / len(good_probs)
-            results[phenomenon] = accuracy
+            results[category] = accuracy
+
+            phenom_results[phenom].append(accuracy)
 
+        phenom_term_averages = {
+            term: sum(accuracies) / len(accuracies)
+            for term, accuracies in phenom_results.items()
+        }
        # Calculate overall accuracy
        overall_accuracy = sum(results.values()) / len(results)
 
-        return {"phenomenon_accuracies": results, "overall_accuracy": overall_accuracy}
+        return {
+            "by_uid": results,
+            "accuracy": overall_accuracy,
+            "by_phenomenon": phenom_term_averages,
+        }
 
 
 def get_batch_probabilities(
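
For orientation: every BLiMP UID (e.g. "adjunct_island") belongs to a broader linguistics term, which the metric now reads from dataset[0]["linguistics_term"] and uses to average per-UID accuracies into the new by_phenomenon entry of the result, next to by_uid (per-UID accuracies) and the overall accuracy. A minimal standalone sketch of that aggregation, using made-up accuracy numbers and a hand-written UID-to-term mapping purely for illustration:

from collections import defaultdict

# Illustrative per-UID accuracies (invented numbers, not measured results).
by_uid = {
    "anaphor_gender_agreement": 0.95,
    "anaphor_number_agreement": 0.97,
    "adjunct_island": 0.88,
}

# The metric reads each UID's linguistics_term from the loaded dataset;
# it is hard-coded here to keep the sketch self-contained.
uid_to_term = {
    "anaphor_gender_agreement": "anaphor_agreement",
    "anaphor_number_agreement": "anaphor_agreement",
    "adjunct_island": "island_effects",
}

# Group per-UID accuracies by linguistics term, then average within each group.
phenom_results = defaultdict(list)
for uid, acc in by_uid.items():
    phenom_results[uid_to_term[uid]].append(acc)

by_phenomenon = {
    term: sum(accs) / len(accs) for term, accs in phenom_results.items()
}
print(by_phenomenon)
# -> {'anaphor_agreement': 0.96, 'island_effects': 0.88} (up to float rounding)

Note that the top-level keys of the returned dict change from phenomenon_accuracies / overall_accuracy to by_uid / accuracy, so downstream code that reads the old keys needs updating.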