omidf committed on
Commit
163936c
·
1 Parent(s): 32e4ba0

Update compute_score.py

Browse files
Files changed (1) hide show
  1. compute_score.py +38 -9
compute_score.py CHANGED
@@ -25,8 +25,17 @@ def normalize_answer(s):
25
 
26
  return white_space_fix(remove_articles(remove_punc(lower(s))))
27
 
 
 
 
 
 
 
 
 
 
28
 
29
- def f1_score(prediction, ground_truth):
30
  prediction_tokens = normalize_answer(prediction).split()
31
  ground_truth_tokens = normalize_answer(ground_truth).split()
32
  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
@@ -34,9 +43,20 @@ def f1_score(prediction, ground_truth):
34
  if num_same == 0:
35
  return 0
36
  precision = 1.0 * num_same / len(prediction_tokens)
37
- recall = 1.0 * num_same / len(ground_truth_tokens)
38
- f1 = (2 * precision * recall) / (precision + recall)
39
- return f1
 
 
 
 
 
 
 
 
 
 
 
40
 
41
 
42
  def exact_match_score(prediction, ground_truth):
@@ -52,7 +72,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
52
 
53
 
54
  def compute_score(dataset, predictions):
55
- f1 = exact_match = total = 0
56
  for article in dataset:
57
  for paragraph in article["paragraphs"]:
58
  for qa in paragraph["qas"]:
@@ -64,15 +84,24 @@ def compute_score(dataset, predictions):
64
  ground_truths = list(map(lambda x: x["text"], qa["answers"]))
65
  prediction = predictions[qa["id"]]
66
  exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
67
- f1_temp = metric_max_over_ground_truths(f1_score, prediction, ground_truths)
68
- print(f1_temp)
69
- f1 += f1_temp
 
 
 
 
70
 
 
 
 
71
 
72
  exact_match = 100.0 * exact_match / total
73
  f1 = 100.0 * f1 / total
 
 
74
 
75
- return {"exact_match": exact_match, "f1": f1}
76
 
77
 
78
  if __name__ == "__main__":
 
25
 
26
  return white_space_fix(remove_articles(remove_punc(lower(s))))
27
 
28
def recall_score(prediction, ground_truth):
    """Return the token-level recall of ``prediction`` against ``ground_truth``.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Recall is the number of tokens the two share (counted with
    multiplicity via ``Counter`` intersection) divided by the number of
    ground-truth tokens.

    Returns:
        0 when the two token bags share nothing, otherwise a float recall.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()

    # Multiset intersection keeps, per token, the smaller of the two counts.
    shared = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(shared.values())
    if not overlap:
        return 0

    return float(overlap) / len(gold_tokens)
37
 
38
+ def precision_score(prediction, ground_truth):
39
  prediction_tokens = normalize_answer(prediction).split()
40
  ground_truth_tokens = normalize_answer(ground_truth).split()
41
  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
 
43
  if num_same == 0:
44
  return 0
45
  precision = 1.0 * num_same / len(prediction_tokens)
46
+ return precision
47
+
48
+
49
+ # def f1_score(prediction, ground_truth):
50
+ # prediction_tokens = normalize_answer(prediction).split()
51
+ # ground_truth_tokens = normalize_answer(ground_truth).split()
52
+ # common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
53
+ # num_same = sum(common.values())
54
+ # if num_same == 0:
55
+ # return 0
56
+ # precision = 1.0 * num_same / len(prediction_tokens)
57
+ # recall = 1.0 * num_same / len(ground_truth_tokens)
58
+ # f1 = (2 * precision * recall) / (precision + recall)
59
+ # return f1
60
 
61
 
62
  def exact_match_score(prediction, ground_truth):
 
72
 
73
 
74
  def compute_score(dataset, predictions):
75
+ recall = precision = f1 = exact_match = total = 0
76
  for article in dataset:
77
  for paragraph in article["paragraphs"]:
78
  for qa in paragraph["qas"]:
 
84
  ground_truths = list(map(lambda x: x["text"], qa["answers"]))
85
  prediction = predictions[qa["id"]]
86
  exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
87
# Per-question recall and precision, each taken via
# metric_max_over_ground_truths over the reference answers.
# NOTE(review): the two metrics are maximized independently, so the F1 below
# may combine a recall and a precision that come from different reference
# answers -- confirm this is intended versus a max over per-reference F1.
recall_temp = metric_max_over_ground_truths(recall_score, prediction, ground_truths)
precision_temp = metric_max_over_ground_truths(precision_score, prediction, ground_truths)

# F1 for THIS question must be built from the per-question values
# (recall_temp / precision_temp). Using the running totals `recall` and
# `precision` here divides by zero on the first question with any overlap
# (both totals are still 0) and yields wrong scores afterwards.
if recall_temp + precision_temp == 0:
    f1_temp = 0
else:
    f1_temp = (2 * precision_temp * recall_temp) / (precision_temp + recall_temp)

# Accumulate the per-question values; they are averaged over `total` later.
f1 += f1_temp
recall += recall_temp
precision += precision_temp
98
 
99
  exact_match = 100.0 * exact_match / total
100
  f1 = 100.0 * f1 / total
101
+ precision = 100.0 * precision / total
102
+ recall = 100.0 * recall / total
103
 
104
+ return {"exact_match": exact_match, "f1": f1, "recall": recall, "precision": precision}
105
 
106
 
107
  if __name__ == "__main__":