Spaces:
Runtime error
Runtime error
File size: 8,625 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for SQuAD version 2.0.
The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py
In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question ID's to the model's predicted probability
that a question is unanswerable.
"""
import collections
import re
import string
from absl import logging
def _make_qid_to_has_ans(dataset):
qid_to_has_ans = {}
for article in dataset:
for p in article['paragraphs']:
for qa in p['qas']:
qid_to_has_ans[qa['id']] = bool(qa['answers'])
return qid_to_has_ans
def _normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
return re.sub(regex, ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def _get_tokens(s):
  """Split the normalized form of `s` on whitespace.

  Falsy input (such as None or '') yields an empty list and is not normalized.
  """
  return _normalize_answer(s).split() if s else []
def _compute_exact(a_gold, a_pred):
  """Return 1 if the gold and predicted answers normalize identically, else 0."""
  gold_normalized = _normalize_answer(a_gold)
  pred_normalized = _normalize_answer(a_pred)
  return 1 if gold_normalized == pred_normalized else 0
def _compute_f1(a_gold, a_pred):
  """Compute token-level F1 between a gold answer and a predicted answer.

  Returns:
    - If either answer has no tokens (the no-answer case): 1 when both are
      empty, 0 otherwise.
    - If the answers share no tokens: 0.
    - Otherwise: the harmonic mean of token precision and token recall, as
      a float. Overlap is counted with multiplicity.
  """
  gold_tokens = _get_tokens(a_gold)
  pred_tokens = _get_tokens(a_pred)
  if not (gold_tokens and pred_tokens):
    return int(gold_tokens == pred_tokens)
  overlap = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
  shared = sum(overlap.values())
  if not shared:
    return 0
  precision = shared / len(pred_tokens)
  recall = shared / len(gold_tokens)
  return 2 * precision * recall / (precision + recall)
def _get_raw_scores(dataset, predictions):
  """Score every predicted answer against its gold answers.

  Gold answers whose text normalizes to '' are ignored. A question with no
  remaining gold answers is treated as unanswerable, and its only gold
  answer is ''. Questions without a prediction are logged and skipped.

  Args:
    dataset: list of SQuAD articles ('paragraphs' -> 'qas').
    predictions: dict mapping question ID to predicted answer text.

  Returns:
    Tuple (exact_scores, f1_scores). Each is a dict from question ID to the
    best score over that question's gold answers.
  """
  exact_scores = {}
  f1_scores = {}
  all_questions = (qa
                   for article in dataset
                   for paragraph in article['paragraphs']
                   for qa in paragraph['qas'])
  for qa in all_questions:
    qid = qa['id']
    gold_answers = [answer['text'] for answer in qa['answers']
                    if _normalize_answer(answer['text'])]
    # For unanswerable questions, the only correct answer is the empty string.
    gold_answers = gold_answers or ['']
    if qid not in predictions:
      logging.error('Missing prediction for %s', qid)
      continue
    prediction = predictions[qid]
    exact_scores[qid] = max(
        _compute_exact(gold, prediction) for gold in gold_answers)
    f1_scores[qid] = max(
        _compute_f1(gold, prediction) for gold in gold_answers)
  return exact_scores, f1_scores
def _apply_no_ans_threshold(
scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
new_scores = {}
for qid, s in scores.items():
pred_na = na_probs[qid] > na_prob_thresh
if pred_na:
new_scores[qid] = float(not qid_to_has_ans[qid])
else:
new_scores[qid] = s
return new_scores
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
"""Make evaluation result dictionary."""
if not qid_list:
total = len(exact_scores)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores.values()) / total),
('f1', 100.0 * sum(f1_scores.values()) / total),
('total', total),
])
else:
total = len(qid_list)
return collections.OrderedDict([
('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
('total', total),
])
def _merge_eval(main_eval, new_eval, prefix):
for k in new_eval:
main_eval['%s_%s' % (prefix, k)] = new_eval[k]
def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
"""Make evaluation dictionary containing average recision recall."""
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
true_pos = 0.0
cur_p = 1.0
cur_r = 0.0
precisions = [1.0]
recalls = [0.0]
avg_prec = 0.0
for i, qid in enumerate(qid_list):
if qid_to_has_ans[qid]:
true_pos += scores[qid]
cur_p = true_pos / float(i+1)
cur_r = true_pos / float(num_true_pos)
if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
# i.e., if we can put a threshold after this point
avg_prec += cur_p * (cur_r - recalls[-1])
precisions.append(cur_p)
recalls.append(cur_r)
return {'ap': 100.0 * avg_prec}
def _run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Add average-precision metrics to `main_eval` in place.

  Average precision is computed for three score sets:
    - the exact-match scores ('pr_exact_ap');
    - the F1 scores ('pr_f1_ap');
    - an oracle that scores every answerable question 1.0 and every other
      question 0.0 ('pr_oracle_ap').
  Nothing is added when no question is answerable.
  """
  num_true_pos = sum(bool(has_ans) for has_ans in qid_to_has_ans.values())
  if not num_true_pos:
    return
  oracle_scores = {
      qid: float(has_ans) for qid, has_ans in qid_to_has_ans.items()}
  score_sets = (
      ('pr_exact', exact_raw),
      ('pr_f1', f1_raw),
      ('pr_oracle', oracle_scores),
  )
  # Compute all three results before merging any of them.
  results = [
      (prefix, _make_precision_recall_eval(
          raw, na_probs, num_true_pos, qid_to_has_ans))
      for prefix, raw in score_sets
  ]
  for prefix, pr_eval in results:
    _merge_eval(main_eval, pr_eval, prefix)
def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
"""Find the best threshold for no answer probability."""
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
cur_score = num_no_ans
best_score = cur_score
best_thresh = 0.0
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
for qid in qid_list:
if qid not in scores: continue
if qid_to_has_ans[qid]:
diff = scores[qid]
else:
if predictions[qid]:
diff = -1
else:
diff = 0
cur_score += diff
if cur_score > best_score:
best_score = cur_score
best_thresh = na_probs[qid]
return 100.0 * best_score / len(scores), best_thresh
def _find_all_best_thresh(
    main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Store the best-threshold exact-match and F1 results in `main_eval`.

  `main_eval` is mutated in place. The keys 'final_exact',
  'final_exact_thresh', 'final_f1' and 'final_f1_thresh' are added, in that
  order, from two calls to `_find_best_thresh`.
  """
  exact_result = _find_best_thresh(
      predictions, exact_raw, na_probs, qid_to_has_ans)
  f1_result = _find_best_thresh(
      predictions, f1_raw, na_probs, qid_to_has_ans)
  main_eval['final_exact'], main_eval['final_exact_thresh'] = exact_result
  main_eval['final_f1'], main_eval['final_f1_thresh'] = f1_result
def evaluate(dataset, predictions, na_probs=None):
  """Evaluate SQuAD 2.0 predictions.

  Only questions that have a prediction are evaluated.

  Args:
    dataset: list of SQuAD articles ('paragraphs' -> 'qas' -> 'id' and
      'answers').
    predictions: dict mapping question ID to predicted answer text.
    na_probs: optional dict mapping question ID to predicted no-answer
      probability. When omitted, every ID in `predictions` gets 0.0.

  Returns:
    OrderedDict of metrics:
      - 'exact', 'f1', 'total';
      - 'HasAns_*' and 'NoAns_*' subsets, when such questions exist;
      - the 'final_*' best-threshold results;
      - the 'pr_*_ap' average precisions, when any question is answerable.
  """
  # Keep only the questions that were predicted. Each one is wrapped in its
  # own single-paragraph article so the helpers can walk it as usual.
  dataset = [
      {'paragraphs': [{'qas': [qa]}]}
      for article in dataset
      for paragraph in article['paragraphs']
      for qa in paragraph['qas']
      if qa['id'] in predictions
  ]
  if na_probs is None:
    na_probs = dict.fromkeys(predictions, 0.0)

  qid_to_has_ans = _make_qid_to_has_ans(dataset)
  has_ans_qids = [qid for qid, has_ans in qid_to_has_ans.items() if has_ans]
  no_ans_qids = [qid for qid, has_ans in qid_to_has_ans.items() if not has_ans]

  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)

  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  for prefix, subset in (('HasAns', has_ans_qids), ('NoAns', no_ans_qids)):
    if subset:
      subset_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=subset)
      _merge_eval(out_eval, subset_eval, prefix)

  _find_all_best_thresh(
      out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  _run_precision_recall_analysis(
      out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  return out_eval
|