Pradeep Kumar commited on
Commit
e0aefa8
·
verified ·
1 Parent(s): d23a681

Delete squad_evaluate_v1_1.py

Browse files
Files changed (1) hide show
  1. squad_evaluate_v1_1.py +0 -106
squad_evaluate_v1_1.py DELETED
@@ -1,106 +0,0 @@
1
- # Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- """Evaluation of SQuAD predictions (version 1.1).
16
-
17
- The functions are copied from
18
- https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/.
19
-
20
- The SQuAD dataset is described in this paper:
21
- SQuAD: 100,000+ Questions for Machine Comprehension of Text
22
- Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang
23
- https://nlp.stanford.edu/pubs/rajpurkar2016squad.pdf
24
- """
25
-
26
- import collections
27
- import re
28
- import string
29
-
30
- # pylint: disable=g-bad-import-order
31
-
32
- from absl import logging
33
- # pylint: enable=g-bad-import-order
34
-
35
-
36
- def _normalize_answer(s):
37
- """Lowers text and remove punctuation, articles and extra whitespace."""
38
-
39
- def remove_articles(text):
40
- return re.sub(r"\b(a|an|the)\b", " ", text)
41
-
42
- def white_space_fix(text):
43
- return " ".join(text.split())
44
-
45
- def remove_punc(text):
46
- exclude = set(string.punctuation)
47
- return "".join(ch for ch in text if ch not in exclude)
48
-
49
- def lower(text):
50
- return text.lower()
51
-
52
- return white_space_fix(remove_articles(remove_punc(lower(s))))
53
-
54
-
55
- def _f1_score(prediction, ground_truth):
56
- """Computes F1 score by comparing prediction to ground truth."""
57
- prediction_tokens = _normalize_answer(prediction).split()
58
- ground_truth_tokens = _normalize_answer(ground_truth).split()
59
- prediction_counter = collections.Counter(prediction_tokens)
60
- ground_truth_counter = collections.Counter(ground_truth_tokens)
61
- common = prediction_counter & ground_truth_counter
62
- num_same = sum(common.values())
63
- if num_same == 0:
64
- return 0
65
- precision = 1.0 * num_same / len(prediction_tokens)
66
- recall = 1.0 * num_same / len(ground_truth_tokens)
67
- f1 = (2 * precision * recall) / (precision + recall)
68
- return f1
69
-
70
-
71
- def _exact_match_score(prediction, ground_truth):
72
- """Checks if predicted answer exactly matches ground truth answer."""
73
- return _normalize_answer(prediction) == _normalize_answer(ground_truth)
74
-
75
-
76
- def _metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
77
- """Computes the max over all metric scores."""
78
- scores_for_ground_truths = []
79
- for ground_truth in ground_truths:
80
- score = metric_fn(prediction, ground_truth)
81
- scores_for_ground_truths.append(score)
82
- return max(scores_for_ground_truths)
83
-
84
-
85
- def evaluate(dataset, predictions):
86
- """Evaluates predictions for a dataset."""
87
- f1 = exact_match = total = 0
88
- for article in dataset:
89
- for paragraph in article["paragraphs"]:
90
- for qa in paragraph["qas"]:
91
- total += 1
92
- if qa["id"] not in predictions:
93
- message = "Unanswered question " + qa["id"] + " will receive score 0."
94
- logging.error(message)
95
- continue
96
- ground_truths = [entry["text"] for entry in qa["answers"]]
97
- prediction = predictions[qa["id"]]
98
- exact_match += _metric_max_over_ground_truths(_exact_match_score,
99
- prediction, ground_truths)
100
- f1 += _metric_max_over_ground_truths(_f1_score, prediction,
101
- ground_truths)
102
-
103
- exact_match = exact_match / total
104
- f1 = f1 / total
105
-
106
- return {"exact_match": exact_match, "final_f1": f1}