Pradeep Kumar committed
Commit 329d378 · verified · 1 Parent(s): 82a3db7

Delete squad_evaluate_v2_0.py

Files changed (1)
  1. squad_evaluate_v2_0.py +0 -249
squad_evaluate_v2_0.py DELETED
@@ -1,249 +0,0 @@
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Evaluation script for SQuAD version 2.0.
-
- The functions are copied and modified from
- https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
-
- In addition to basic functionality, we also compute additional statistics and
- plot precision-recall curves if an additional na_prob.json file is provided.
- This file is expected to map question ID's to the model's predicted probability
- that a question is unanswerable.
- """
-
- import collections
- import re
- import string
-
- from absl import logging
-
-
- def _make_qid_to_has_ans(dataset):
-   qid_to_has_ans = {}
-   for article in dataset:
-     for p in article['paragraphs']:
-       for qa in p['qas']:
-         qid_to_has_ans[qa['id']] = bool(qa['answers'])
-   return qid_to_has_ans
-
-
- def _normalize_answer(s):
-   """Lower text and remove punctuation, articles and extra whitespace."""
-   def remove_articles(text):
-     regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
-     return re.sub(regex, ' ', text)
-   def white_space_fix(text):
-     return ' '.join(text.split())
-   def remove_punc(text):
-     exclude = set(string.punctuation)
-     return ''.join(ch for ch in text if ch not in exclude)
-   def lower(text):
-     return text.lower()
-   return white_space_fix(remove_articles(remove_punc(lower(s))))
-
-
- def _get_tokens(s):
-   if not s: return []
-   return _normalize_answer(s).split()
-
-
- def _compute_exact(a_gold, a_pred):
-   return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))
-
-
- def _compute_f1(a_gold, a_pred):
-   """Compute F1-score."""
-   gold_toks = _get_tokens(a_gold)
-   pred_toks = _get_tokens(a_pred)
-   common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
-   num_same = sum(common.values())
-   if not gold_toks or not pred_toks:
-     # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
-     return int(gold_toks == pred_toks)
-   if num_same == 0:
-     return 0
-   precision = 1.0 * num_same / len(pred_toks)
-   recall = 1.0 * num_same / len(gold_toks)
-   f1 = (2 * precision * recall) / (precision + recall)
-   return f1
-
-
- def _get_raw_scores(dataset, predictions):
-   """Compute raw scores."""
-   exact_scores = {}
-   f1_scores = {}
-   for article in dataset:
-     for p in article['paragraphs']:
-       for qa in p['qas']:
-         qid = qa['id']
-         gold_answers = [a['text'] for a in qa['answers']
-                         if _normalize_answer(a['text'])]
-         if not gold_answers:
-           # For unanswerable questions, only correct answer is empty string
-           gold_answers = ['']
-         if qid not in predictions:
-           logging.error('Missing prediction for %s', qid)
-           continue
-         a_pred = predictions[qid]
-         # Take max over all gold answers
-         exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
-         f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
-   return exact_scores, f1_scores
-
-
- def _apply_no_ans_threshold(
-     scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
-   new_scores = {}
-   for qid, s in scores.items():
-     pred_na = na_probs[qid] > na_prob_thresh
-     if pred_na:
-       new_scores[qid] = float(not qid_to_has_ans[qid])
-     else:
-       new_scores[qid] = s
-   return new_scores
-
-
- def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
-   """Make evaluation result dictionary."""
-   if not qid_list:
-     total = len(exact_scores)
-     return collections.OrderedDict([
-         ('exact', 100.0 * sum(exact_scores.values()) / total),
-         ('f1', 100.0 * sum(f1_scores.values()) / total),
-         ('total', total),
-     ])
-   else:
-     total = len(qid_list)
-     return collections.OrderedDict([
-         ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
-         ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
-         ('total', total),
-     ])
-
-
- def _merge_eval(main_eval, new_eval, prefix):
-   for k in new_eval:
-     main_eval['%s_%s' % (prefix, k)] = new_eval[k]
-
-
- def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
-   """Make evaluation dictionary containing average precision recall."""
-   qid_list = sorted(na_probs, key=lambda k: na_probs[k])
-   true_pos = 0.0
-   cur_p = 1.0
-   cur_r = 0.0
-   precisions = [1.0]
-   recalls = [0.0]
-   avg_prec = 0.0
-   for i, qid in enumerate(qid_list):
-     if qid_to_has_ans[qid]:
-       true_pos += scores[qid]
-     cur_p = true_pos / float(i+1)
-     cur_r = true_pos / float(num_true_pos)
-     if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
-       # i.e., if we can put a threshold after this point
-       avg_prec += cur_p * (cur_r - recalls[-1])
-       precisions.append(cur_p)
-       recalls.append(cur_r)
-   return {'ap': 100.0 * avg_prec}
-
-
- def _run_precision_recall_analysis(
-     main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
-   """Run precision recall analysis and return result dictionary."""
-   num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
-   if num_true_pos == 0:
-     return
-   pr_exact = _make_precision_recall_eval(
-       exact_raw, na_probs, num_true_pos, qid_to_has_ans)
-   pr_f1 = _make_precision_recall_eval(
-       f1_raw, na_probs, num_true_pos, qid_to_has_ans)
-   oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
-   pr_oracle = _make_precision_recall_eval(
-       oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
-   _merge_eval(main_eval, pr_exact, 'pr_exact')
-   _merge_eval(main_eval, pr_f1, 'pr_f1')
-   _merge_eval(main_eval, pr_oracle, 'pr_oracle')
-
-
- def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
-   """Find the best threshold for no answer probability."""
-   num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
-   cur_score = num_no_ans
-   best_score = cur_score
-   best_thresh = 0.0
-   qid_list = sorted(na_probs, key=lambda k: na_probs[k])
-   for qid in qid_list:
-     if qid not in scores: continue
-     if qid_to_has_ans[qid]:
-       diff = scores[qid]
-     else:
-       if predictions[qid]:
-         diff = -1
-       else:
-         diff = 0
-     cur_score += diff
-     if cur_score > best_score:
-       best_score = cur_score
-       best_thresh = na_probs[qid]
-   return 100.0 * best_score / len(scores), best_thresh
-
-
- def _find_all_best_thresh(
-     main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
-   best_exact, exact_thresh = _find_best_thresh(
-       predictions, exact_raw, na_probs, qid_to_has_ans)
-   best_f1, f1_thresh = _find_best_thresh(
-       predictions, f1_raw, na_probs, qid_to_has_ans)
-   main_eval['final_exact'] = best_exact
-   main_eval['final_exact_thresh'] = exact_thresh
-   main_eval['final_f1'] = best_f1
-   main_eval['final_f1_thresh'] = f1_thresh
-
-
- def evaluate(dataset, predictions, na_probs=None):
-   """Evaluate prediction results."""
-   new_orig_data = []
-   for article in dataset:
-     for p in article['paragraphs']:
-       for qa in p['qas']:
-         if qa['id'] in predictions:
-           new_para = {'qas': [qa]}
-           new_article = {'paragraphs': [new_para]}
-           new_orig_data.append(new_article)
-   dataset = new_orig_data
-
-   if na_probs is None:
-     na_probs = {k: 0.0 for k in predictions}
-   qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
-   has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
-   no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
-   exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
-   exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
-   f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
-   out_eval = _make_eval_dict(exact_thresh, f1_thresh)
-   if has_ans_qids:
-     has_ans_eval = _make_eval_dict(
-         exact_thresh, f1_thresh, qid_list=has_ans_qids)
-     _merge_eval(out_eval, has_ans_eval, 'HasAns')
-   if no_ans_qids:
-     no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
-     _merge_eval(out_eval, no_ans_eval, 'NoAns')
-
-   _find_all_best_thresh(
-       out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
-   _run_precision_recall_analysis(
-       out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
-   return out_eval
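
For reference, a minimal usage sketch of the deleted script's entry point. It follows the module docstring: predictions map question IDs to answer strings, and an optional na_prob.json maps question IDs to the predicted probability that a question is unanswerable. The import path and the file names dev-v2.0.json, predictions.json, and na_prob.json are assumptions for illustration, not part of the original file.

import json

import squad_evaluate_v2_0  # hypothetical import path for the deleted module

# SQuAD 2.0 dev set: the 'data' field is a list of articles with 'paragraphs'/'qas'.
with open('dev-v2.0.json') as f:
  dataset = json.load(f)['data']

# Model output: question ID -> predicted answer text ('' for "no answer").
with open('predictions.json') as f:
  predictions = json.load(f)

# Optional: question ID -> predicted probability the question is unanswerable.
with open('na_prob.json') as f:
  na_probs = json.load(f)

results = squad_evaluate_v2_0.evaluate(dataset, predictions, na_probs=na_probs)
# 'exact'/'f1' are percentages over all answered questions; 'final_f1' and
# 'final_f1_thresh' come from sweeping the no-answer probability threshold.
print(results['exact'], results['f1'], results['final_f1'], results['final_f1_thresh'])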