lvwerra HF staff commited on
Commit
08e2f88
·
1 Parent(s): a299c46

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. mae.py +21 -3
  2. requirements.txt +1 -1
mae.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """MAE - Mean Absolute Error Metric"""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import mean_absolute_error
18
 
@@ -81,13 +84,26 @@ Examples:
81
  """
82
 
83
 
 
 
 
 
 
 
 
 
 
84
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85
  class Mae(evaluate.Metric):
86
- def _info(self):
 
 
 
87
  return evaluate.MetricInfo(
88
  description=_DESCRIPTION,
89
  citation=_CITATION,
90
  inputs_description=_KWARGS_DESCRIPTION,
 
91
  features=datasets.Features(self._get_feature_types()),
92
  reference_urls=[
93
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html"
@@ -106,8 +122,10 @@ class Mae(evaluate.Metric):
106
  "references": datasets.Value("float"),
107
  }
108
 
109
- def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average"):
110
 
111
- mae_score = mean_absolute_error(references, predictions, sample_weight=sample_weight, multioutput=multioutput)
 
 
112
 
113
  return {"mae": mae_score}
 
13
  # limitations under the License.
14
  """MAE - Mean Absolute Error Metric"""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional, Union
18
+
19
  import datasets
20
  from sklearn.metrics import mean_absolute_error
21
 
 
84
  """
85
 
86
 
87
+ @dataclass
88
+ class MaeConfig(evaluate.info.Config):
89
+
90
+ name: str = "default"
91
+
92
+ multioutput: str = "uniform_average"
93
+ sample_weight: Optional[List[float]] = None
94
+
95
+
96
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
97
  class Mae(evaluate.Metric):
98
+ CONFIG_CLASS = MaeConfig
99
+ ALLOWED_CONFIG_NAMES = ["default", "multilist"]
100
+
101
+ def _info(self, config):
102
  return evaluate.MetricInfo(
103
  description=_DESCRIPTION,
104
  citation=_CITATION,
105
  inputs_description=_KWARGS_DESCRIPTION,
106
+ config=config,
107
  features=datasets.Features(self._get_feature_types()),
108
  reference_urls=[
109
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html"
 
122
  "references": datasets.Value("float"),
123
  }
124
 
125
+ def _compute(self, predictions, references):
126
 
127
+ mae_score = mean_absolute_error(
128
+ references, predictions, sample_weight=self.config.sample_weight, multioutput=self.config.multioutput
129
+ )
130
 
131
  return {"mae": mae_score}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn