tybrs committed on
Commit 820d9c2 · 1 Parent(s): a09e327

Create glue.py

Files changed (1)
  1. glue.py +153 -0
glue.py ADDED
@@ -0,0 +1,153 @@
+ # Copyright 2020 The HuggingFace Evaluate Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """GLUE benchmark metric."""
+
+ import datasets
+ from scipy.stats import pearsonr, spearmanr
+ from sklearn.metrics import f1_score, matthews_corrcoef
+
+ import evaluate
+
+
+ _CITATION = """\
+ @inproceedings{wang2019glue,
+     title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
+     author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
+     note={In the Proceedings of ICLR.},
+     year={2019}
+ }
+ """
+
+ _DESCRIPTION = """\
+ GLUE, the General Language Understanding Evaluation benchmark
+ (https://gluebenchmark.com/) is a collection of resources for training,
+ evaluating, and analyzing natural language understanding systems.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Compute the GLUE evaluation metric associated with each GLUE dataset.
+ Args:
+     predictions: list of predictions to score.
+         Each prediction is a class label (int), or a float score for the stsb subset.
+     references: list of references, one per prediction.
+         Each reference is a class label (int), or a float score for the stsb subset.
+ Returns: depending on the GLUE subset, one or several of:
+     "accuracy": Accuracy
+     "f1": F1 score
+     "pearson": Pearson Correlation
+     "spearmanr": Spearman Correlation
+     "matthews_correlation": Matthews Correlation
+ Examples:
+     >>> glue_metric = evaluate.load('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
+     >>> references = [0, 1]
+     >>> predictions = [0, 1]
+     >>> results = glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'accuracy': 1.0}
+     >>> glue_metric = evaluate.load('glue', 'mrpc')  # 'mrpc' or 'qqp'
+     >>> references = [0, 1]
+     >>> predictions = [0, 1]
+     >>> results = glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'accuracy': 1.0, 'f1': 1.0}
+     >>> glue_metric = evaluate.load('glue', 'stsb')
+     >>> references = [0., 1., 2., 3., 4., 5.]
+     >>> predictions = [0., 1., 2., 3., 4., 5.]
+     >>> results = glue_metric.compute(predictions=predictions, references=references)
+     >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)})
+     {'pearson': 1.0, 'spearmanr': 1.0}
+     >>> glue_metric = evaluate.load('glue', 'cola')
+     >>> references = [0, 1]
+     >>> predictions = [0, 1]
+     >>> results = glue_metric.compute(predictions=predictions, references=references)
+     >>> print(results)
+     {'matthews_correlation': 1.0}
+ """
+
+
+ def simple_accuracy(preds, labels):
+     return float((preds == labels).mean())
+
+
+ def acc_and_f1(preds, labels):
+     acc = simple_accuracy(preds, labels)
+     f1 = float(f1_score(y_true=labels, y_pred=preds))
+     return {
+         "accuracy": acc,
+         "f1": f1,
+     }
+
+
+ def pearson_and_spearman(preds, labels):
+     pearson_corr = float(pearsonr(preds, labels)[0])
+     spearman_corr = float(spearmanr(preds, labels)[0])
+     return {
+         "pearson": pearson_corr,
+         "spearmanr": spearman_corr,
+     }
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class Glue(evaluate.Metric):
+     def _info(self):
+         if self.config_name not in [
+             "sst2",
+             "mnli",
+             "mnli_mismatched",
+             "mnli_matched",
+             "cola",
+             "stsb",
+             "mrpc",
+             "qqp",
+             "qnli",
+             "rte",
+             "wnli",
+             "hans",
+         ]:
+             raise KeyError(
+                 "You should supply a configuration name selected in "
+                 '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
+                 '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
+             )
+         return evaluate.MetricInfo(
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
+                     "references": datasets.Value("int64" if self.config_name != "stsb" else "float32"),
+                 }
+             ),
+             codebase_urls=[],
+             reference_urls=[],
+             format="numpy",
+         )
+
+     def _compute(self, predictions, references, config_name=None):
+         # Only override the configured subset when one is explicitly passed, so the
+         # default `metric.compute(predictions=..., references=...)` call keeps working.
+         if config_name is not None:
+             self.config_name = config_name
+         if self.config_name == "cola":
+             return {"matthews_correlation": matthews_corrcoef(references, predictions)}
+         elif self.config_name == "stsb":
+             return pearson_and_spearman(predictions, references)
+         elif self.config_name in ["mrpc", "qqp"]:
+             return acc_and_f1(predictions, references)
+         elif self.config_name in ["sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]:
+             return {"accuracy": simple_accuracy(predictions, references)}
+         else:
+             raise KeyError(
+                 "You should supply a configuration name selected in "
+                 '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
+                 '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
+             )
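
For readers of this commit, a minimal usage sketch (not part of the committed file). It assumes the script is saved locally as glue.py and loaded by path with evaluate.load, and it uses the config_name keyword that this commit adds to _compute, which evaluate forwards from compute() as an extra keyword argument.

import evaluate

# Load the metric from the local script added in this commit
# (path-based loading is an assumption; the Hub id would simply be "glue").
glue_metric = evaluate.load("glue.py", "mrpc")

predictions = [0, 1, 1, 0]
references = [0, 1, 0, 0]

# Default: score against the subset chosen at load time ("mrpc" -> accuracy + F1).
print(glue_metric.compute(predictions=predictions, references=references))
# {'accuracy': 0.75, 'f1': 0.666...}

# Override: reuse the same metric object for another subset via the
# config_name keyword introduced in this commit's _compute.
print(glue_metric.compute(predictions=predictions, references=references, config_name="sst2"))
# {'accuracy': 0.75}

Note that the override only changes the scoring branch; the input features (int64 vs. float32) are still fixed by the configuration passed at load time, so switching to or from "stsb" this way would still call for reloading the metric.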