Josh98 commited on
Commit
abbc496
·
1 Parent(s): c946cad

change file name

Browse files
Files changed (1) hide show
  1. nl2bash_metric.py +130 -0
nl2bash_metric.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Exact Match metric."""
15
+ import re
16
+ import string
17
+
18
+ import numpy as np
19
+
20
+ import datasets
21
+
22
+
23
+ _DESCRIPTION = """
24
+ Returns the rate at which the input predicted strings exactly match their references, ignoring any strings input as part of the regexes_to_ignore list.
25
+ """
26
+
27
+ _KWARGS_DESCRIPTION = """
28
+ Args:
29
+ predictions: List of predicted texts.
30
+ references: List of reference texts.
31
+ regexes_to_ignore: List, defaults to None. Regex expressions of characters to
32
+ ignore when calculating the exact matches. Note: these regexes are removed
33
+ from the input data before the changes based on the options below (e.g. ignore_case,
34
+ ignore_punctuation, ignore_numbers) are applied.
35
+ ignore_case: Boolean, defaults to False. If true, turns everything
36
+ to lowercase so that capitalization differences are ignored.
37
+ ignore_punctuation: Boolean, defaults to False. If true, removes all punctuation before
38
+ comparing predictions and references.
39
+ ignore_numbers: Boolean, defaults to False. If true, removes all punctuation before
40
+ comparing predictions and references.
41
+ Returns:
42
+ exact_match: Dictionary containing exact_match rate. Possible values are between 0.0 and 100.0, inclusive.
43
+ Examples:
44
+ >>> exact_match = datasets.load_metric("exact_match")
45
+ >>> refs = ["the cat", "theater", "YELLING", "agent007"]
46
+ >>> preds = ["cat?", "theater", "yelling", "agent"]
47
+ >>> results = exact_match.compute(references=refs, predictions=preds)
48
+ >>> print(round(results["exact_match"], 1))
49
+ 25.0
50
+ >>> exact_match = datasets.load_metric("exact_match")
51
+ >>> refs = ["the cat", "theater", "YELLING", "agent007"]
52
+ >>> preds = ["cat?", "theater", "yelling", "agent"]
53
+ >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell"], ignore_case=True, ignore_punctuation=True)
54
+ >>> print(round(results["exact_match"], 1))
55
+ 50.0
56
+ >>> exact_match = datasets.load_metric("exact_match")
57
+ >>> refs = ["the cat", "theater", "YELLING", "agent007"]
58
+ >>> preds = ["cat?", "theater", "yelling", "agent"]
59
+ >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell", "YELL"], ignore_case=True, ignore_punctuation=True)
60
+ >>> print(round(results["exact_match"], 1))
61
+ 75.0
62
+ >>> exact_match = datasets.load_metric("exact_match")
63
+ >>> refs = ["the cat", "theater", "YELLING", "agent007"]
64
+ >>> preds = ["cat?", "theater", "yelling", "agent"]
65
+ >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell", "YELL"], ignore_case=True, ignore_punctuation=True, ignore_numbers=True)
66
+ >>> print(round(results["exact_match"], 1))
67
+ 100.0
68
+ >>> exact_match = datasets.load_metric("exact_match")
69
+ >>> refs = ["The cat sat on the mat.", "Theaters are great.", "It's like comparing oranges and apples."]
70
+ >>> preds = ["The cat sat on the mat?", "Theaters are great.", "It's like comparing apples and oranges."]
71
+ >>> results = exact_match.compute(references=refs, predictions=preds)
72
+ >>> print(round(results["exact_match"], 1))
73
+ 33.3
74
+ """
75
+
76
+ _CITATION = """
77
+ """
78
+
79
+
80
+ @datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
81
+ class ExactMatch(datasets.Metric):
82
+ def _info(self):
83
+ return datasets.MetricInfo(
84
+ description=_DESCRIPTION,
85
+ citation=_CITATION,
86
+ inputs_description=_KWARGS_DESCRIPTION,
87
+ features=datasets.Features(
88
+ {
89
+ "predictions": datasets.Value("string", id="sequence"),
90
+ "references": datasets.Value("string", id="sequence"),
91
+ }
92
+ ),
93
+ reference_urls=[],
94
+ )
95
+
96
+ def _compute(
97
+ self,
98
+ predictions,
99
+ references,
100
+ regexes_to_ignore=None,
101
+ ignore_case=False,
102
+ ignore_punctuation=False,
103
+ ignore_numbers=False,
104
+ ):
105
+
106
+ if regexes_to_ignore is not None:
107
+ for s in regexes_to_ignore:
108
+ predictions = np.array([re.sub(s, "", x) for x in predictions])
109
+ references = np.array([re.sub(s, "", x) for x in references])
110
+ else:
111
+ predictions = np.asarray(predictions)
112
+ references = np.asarray(references)
113
+
114
+ if ignore_case:
115
+ predictions = np.char.lower(predictions)
116
+ references = np.char.lower(references)
117
+
118
+ if ignore_punctuation:
119
+ repl_table = string.punctuation.maketrans("", "", string.punctuation)
120
+ predictions = np.char.translate(predictions, table=repl_table)
121
+ references = np.char.translate(references, table=repl_table)
122
+
123
+ if ignore_numbers:
124
+ repl_table = string.digits.maketrans("", "", string.digits)
125
+ predictions = np.char.translate(predictions, table=repl_table)
126
+ references = np.char.translate(references, table=repl_table)
127
+
128
+ score_list = predictions == references
129
+
130
+ return {"exact_match": np.mean(score_list) * 100}