# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for SQuAD version 2.0. | |
The functions are copied and modified from | |
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py | |
In addition to basic functionality, we also compute additional statistics and | |
plot precision-recall curves if an additional na_prob.json file is provided. | |
This file is expected to map question ID's to the model's predicted probability | |
that a question is unanswerable. | |
""" | |

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  """Map each question ID to whether the question has a gold answer."""
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))
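
# Illustrative examples (not part of the original script): normalization
# lowercases, strips punctuation, drops the articles a/an/the, and collapses
# whitespace, e.g.
#   _normalize_answer('The Eiffel Tower!')  # -> 'eiffel tower'
#   _normalize_answer('a   cat, sat')       # -> 'cat sat'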


def _get_tokens(s):
  """Split a string into normalized tokens; empty/None yields no tokens."""
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  """Compute exact-match score (1 or 0) between normalized answers."""
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1
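
# Worked example (illustrative, not from the original script): for gold
# 'the cat sat' and prediction 'cat sat down', the normalized token lists are
# ['cat', 'sat'] and ['cat', 'sat', 'down'], so num_same = 2,
# precision = 2/3, recall = 2/2, and F1 = 2 * (2/3 * 1) / (2/3 + 1) = 0.8.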


def _get_raw_scores(dataset, predictions):
  """Compute raw exact-match and F1 scores, before any thresholding."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [a['text'] for a in qa['answers']
                        if _normalize_answer(a['text'])]
        if not gold_answers:
          # For unanswerable questions, the only correct answer is the empty
          # string.
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take the max over all gold answers.
        exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(
    scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
  """Override scores for questions the model predicts to be unanswerable."""
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      # The model abstains: score 1 iff the question truly has no answer.
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores
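
# Illustrative note (not in the original script): with the default
# na_prob_thresh=1.0 and probabilities in [0, 1], no question exceeds the
# threshold, so raw scores pass through unchanged. With na_prob_thresh=0.5, a
# question with na_prob 0.9 is treated as predicted-unanswerable: its score
# becomes 1.0 if the gold data marks it unanswerable, and 0.0 otherwise.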


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  """Copy entries of new_eval into main_eval under prefixed keys."""
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
  """Make an evaluation dictionary containing the average precision."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}
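
# Illustrative note (not in the original script): the loop sweeps a no-answer
# threshold over questions in ascending na_prob order; each distinct na_prob
# value yields one precision/recall point, and 'ap' accumulates precision times
# the recall increment, i.e. average precision as a percentage.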


def _run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Run precision-recall analysis and merge the results into main_eval."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for the no-answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh
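
# Worked example (illustrative, not from the original script): with q1
# answerable, exact score 1.0, na_prob 0.2, and q2 unanswerable with a
# non-empty prediction and na_prob 0.8, the sweep starts at cur_score = 1
# (abstain on everything), rises to 2 after q1 (diff = +1) and falls back to 1
# after q2 (diff = -1), so best_score = 2 and best_thresh = 0.2, returning
# (100.0, 0.2).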


def _find_all_best_thresh(
    main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Find the best no-answer thresholds for exact match and F1."""
  best_exact, exact_thresh = _find_best_thresh(
      predictions, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(
      predictions, f1_raw, na_probs, qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  # Keep only the question-answer entries that actually have predictions.
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = _make_eval_dict(
        exact_thresh, f1_thresh, qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(
      out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  _run_precision_recall_analysis(
      out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  return out_eval
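

# A minimal, hypothetical driver showing how these functions fit together. It
# assumes SQuAD-style JSON files: the dataset file holds articles under a
# 'data' key, predictions.json maps question IDs to answer strings, and the
# optional na_prob.json maps question IDs to no-answer probabilities. The file
# names and this __main__ block are illustrative additions, not part of the
# original module.
if __name__ == '__main__':
  import json

  with open('dev-v2.0.json') as f:
    dataset = json.load(f)['data']
  with open('predictions.json') as f:
    predictions = json.load(f)
  try:
    with open('na_prob.json') as f:
      na_probs = json.load(f)
  except FileNotFoundError:
    na_probs = None  # evaluate() then defaults every na_prob to 0.0.
  print(json.dumps(evaluate(dataset, predictions, na_probs), indent=2))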