Josh98 committed
Commit a0a0756 · 1 Parent(s): dcc6afa

change to use exact match first

Files changed (1):
  1. nl2bash_m.py +93 -57
nl2bash_m.py CHANGED
@@ -11,85 +11,121 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""TODO: Add a description here."""
+"""Exact Match metric."""
+import re
+import string
 
-import evaluate
 import datasets
+import numpy as np
 
+import evaluate
 
-# TODO: Add BibTeX citation
-_CITATION = """\
-@InProceedings{huggingface:module,
-title = {A great new module},
-authors={huggingface, Inc.},
-year={2020}
-}
-"""
 
-# TODO: Add description of the module here
-_DESCRIPTION = """\
-This new module is designed to solve this great ML task and is crafted with a lot of care.
+_DESCRIPTION = """
+Returns the rate at which the input predicted strings exactly match their references, ignoring any strings input as part of the regexes_to_ignore list.
 """
 
-
-# TODO: Add description of the arguments of the module here
 _KWARGS_DESCRIPTION = """
-Calculates how good are predictions given some references, using certain scores
 Args:
-    predictions: list of predictions to score. Each predictions
-        should be a string with tokens separated by spaces.
-    references: list of reference for each prediction. Each
-        reference should be a string with tokens separated by spaces.
+    predictions: List of predicted texts.
+    references: List of reference texts.
+    regexes_to_ignore: List, defaults to None. Regex expressions of characters to
+        ignore when calculating the exact matches. Note: these regexes are removed
+        from the input data before the changes based on the options below (e.g. ignore_case,
+        ignore_punctuation, ignore_numbers) are applied.
+    ignore_case: Boolean, defaults to False. If true, turns everything
+        to lowercase so that capitalization differences are ignored.
+    ignore_punctuation: Boolean, defaults to False. If true, removes all punctuation before
+        comparing predictions and references.
+    ignore_numbers: Boolean, defaults to False. If true, removes all numbers before
+        comparing predictions and references.
 Returns:
-    accuracy: description of the first score,
-    another_score: description of the second score,
+    exact_match: Dictionary containing exact_match rate. Possible values are between 0.0 and 1.0, inclusive.
 Examples:
-    Examples should be written in doctest format, and should illustrate how
-    to use the function.
-
-    >>> my_new_module = evaluate.load("my_new_module")
-    >>> results = my_new_module.compute(references=[0, 1], predictions=[0, 1])
-    >>> print(results)
-    {'accuracy': 1.0}
+    >>> exact_match = evaluate.load("exact_match")
+    >>> refs = ["the cat", "theater", "YELLING", "agent007"]
+    >>> preds = ["cat?", "theater", "yelling", "agent"]
+    >>> results = exact_match.compute(references=refs, predictions=preds)
+    >>> print(round(results["exact_match"], 2))
+    0.25
+    >>> exact_match = evaluate.load("exact_match")
+    >>> refs = ["the cat", "theater", "YELLING", "agent007"]
+    >>> preds = ["cat?", "theater", "yelling", "agent"]
+    >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell"], ignore_case=True, ignore_punctuation=True)
+    >>> print(round(results["exact_match"], 2))
+    0.5
+    >>> exact_match = evaluate.load("exact_match")
+    >>> refs = ["the cat", "theater", "YELLING", "agent007"]
+    >>> preds = ["cat?", "theater", "yelling", "agent"]
+    >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell", "YELL"], ignore_case=True, ignore_punctuation=True)
+    >>> print(round(results["exact_match"], 2))
+    0.75
+    >>> exact_match = evaluate.load("exact_match")
+    >>> refs = ["the cat", "theater", "YELLING", "agent007"]
+    >>> preds = ["cat?", "theater", "yelling", "agent"]
+    >>> results = exact_match.compute(references=refs, predictions=preds, regexes_to_ignore=["the ", "yell", "YELL"], ignore_case=True, ignore_punctuation=True, ignore_numbers=True)
+    >>> print(round(results["exact_match"], 2))
+    1.0
+    >>> exact_match = evaluate.load("exact_match")
+    >>> refs = ["The cat sat on the mat.", "Theaters are great.", "It's like comparing oranges and apples."]
+    >>> preds = ["The cat sat on the mat?", "Theaters are great.", "It's like comparing apples and oranges."]
+    >>> results = exact_match.compute(references=refs, predictions=preds)
+    >>> print(round(results["exact_match"], 2))
+    0.33
 """
 
-# TODO: Define external resources urls if needed
-BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
+_CITATION = """
+"""
 
 
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class nl2bash_m(evaluate.Metric):
-    """TODO: Short description of my evaluation module."""
-
     def _info(self):
-        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.MetricInfo(
-            # This is the description that will appear on the modules page.
-            module_type="metric",
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
-            # This defines the format of each prediction and reference
-            features=datasets.Features({
-                'predictions': datasets.Value('int64'),
-                'references': datasets.Value('int64'),
-            }),
-            # Homepage of the module for documentation
-            homepage="http://module.homepage",
-            # Additional links to the codebase or references
-            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
-            reference_urls=["http://path.to.reference.url/new_module"]
+            features=datasets.Features(
+                {
+                    "predictions": datasets.Value("string", id="sequence"),
+                    "references": datasets.Value("string", id="sequence"),
+                }
+            ),
+            reference_urls=[],
         )
 
-    def _download_and_prepare(self, dl_manager):
-        """Optional: download external resources useful to compute the scores"""
-        # TODO: Download external resources if needed
-        pass
-
-    def _compute(self, predictions, references):
-        """Returns the scores"""
-        # TODO: Compute the different scores of the module
-        accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
-        return {
-            "accuracy": accuracy,
-        }
+    def _compute(
+        self,
+        predictions,
+        references,
+        regexes_to_ignore=None,
+        ignore_case=False,
+        ignore_punctuation=False,
+        ignore_numbers=False,
+    ):
+        if regexes_to_ignore is not None:
+            for s in regexes_to_ignore:
+                predictions = np.array([re.sub(s, "", x) for x in predictions])
+                references = np.array([re.sub(s, "", x) for x in references])
+        else:
+            predictions = np.asarray(predictions)
+            references = np.asarray(references)
+
+        if ignore_case:
+            predictions = np.char.lower(predictions)
+            references = np.char.lower(references)
+
+        if ignore_punctuation:
+            repl_table = string.punctuation.maketrans("", "", string.punctuation)
+            predictions = np.char.translate(predictions, table=repl_table)
+            references = np.char.translate(references, table=repl_table)
+
+        if ignore_numbers:
+            repl_table = string.digits.maketrans("", "", string.digits)
+            predictions = np.char.translate(predictions, table=repl_table)
+            references = np.char.translate(references, table=repl_table)
+
+        score_list = predictions == references
 
+        return {"exact_match": np.mean(score_list)}