# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for SQuAD version 2.0.

The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py

In addition to basic functionality, we also compute additional statistics and
run a precision-recall analysis if an additional na_prob.json file is
provided. This file is expected to map question IDs to the model's predicted
probability that a question is unanswerable.
"""

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Compute raw scores."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [
            a['text'] for a in qa['answers'] if _normalize_answer(a['text'])
        ]
        if not gold_answers:
          # For unanswerable questions, the only correct answer is the empty
          # string.
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers.
        exact_scores[qid] = max(
            _compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(
    scores, na_probs, qid_to_has_ans, na_prob_thresh=1.0):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores

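# The sketch below is illustrative only and not part of the original
# evaluation flow; the values are hypothetical. It shows how the thresholding
# above rescores a prediction: when a question's no-answer probability exceeds
# the threshold, the model is treated as having abstained, so it gets credit
# only if the question truly has no answer.
#
#   _apply_no_ans_threshold(
#       scores={'q1': 1.0, 'q2': 0.0},
#       na_probs={'q1': 0.9, 'q2': 0.2},
#       qid_to_has_ans={'q1': True, 'q2': False},
#       na_prob_thresh=0.5)
#   # -> {'q1': 0.0, 'q2': 0.0}: 'q1' is answerable but predicted no-answer,
#   #    so its score drops to 0; 'q2' stays below the threshold and keeps
#   #    its raw score.
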
def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(
    scores, na_probs, num_true_pos, qid_to_has_ans):
  """Make evaluation dictionary containing average precision."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(
    main_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  """Run precision-recall analysis and merge the results into main_eval."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(
      exact_raw, na_probs, num_true_pos, qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(
      f1_raw, na_probs, num_true_pos, qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(
      oracle_scores, na_probs, num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for the no-answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(
    main_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(
      predictions, exact_raw, na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(
      predictions, f1_raw, na_probs, qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh

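# The sketch below is illustrative only and not part of the original
# evaluation flow; the values are hypothetical. It traces the threshold
# search above: candidates are visited in increasing na_prob order, starting
# from the score obtained by predicting no-answer everywhere; each answerable
# question passed adds its raw score, and each unanswerable question with a
# non-empty prediction subtracts one.
#
#   _find_best_thresh(
#       predictions={'q1': 'blue', 'q2': 'red', 'q3': ''},
#       scores={'q1': 1, 'q2': 0, 'q3': 1},
#       na_probs={'q1': 0.1, 'q2': 0.6, 'q3': 0.9},
#       qid_to_has_ans={'q1': True, 'q2': False, 'q3': False})
#   # -> (100.0, 0.1): the best running score (3 of 3 correct) is reached
#   #    right after 'q1', so the best threshold is its na_prob, 0.1.
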
def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = _make_eval_dict(
        exact_thresh, f1_thresh, qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(
        exact_thresh, f1_thresh, qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(
      out_eval, predictions, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  _run_precision_recall_analysis(
      out_eval, exact_raw, f1_raw, na_probs, qid_to_has_ans)
  return out_eval

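# A minimal usage sketch (illustrative only; the toy data below is
# hypothetical and not part of the original evaluation flow). The dataset
# mirrors the SQuAD v2.0 JSON layout (articles -> paragraphs -> qas), and
# predictions maps each question id to the predicted answer text, with the
# empty string meaning "no answer".
if __name__ == '__main__':
  _toy_dataset = [{
      'paragraphs': [{
          'qas': [
              {'id': 'q1', 'answers': [{'text': 'Denver Broncos'}]},
              {'id': 'q2', 'answers': []},  # unanswerable question
          ],
      }],
  }]
  _toy_predictions = {'q1': 'Denver Broncos', 'q2': ''}
  # With na_probs omitted, every no-answer probability defaults to 0.0, so no
  # prediction is rescored by the threshold.
  print(evaluate(_toy_dataset, _toy_predictions))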