lvwerra HF staff committed on
Commit
05b7663
1 Parent(s): 08e2f88

Update Space (evaluate main: c447fc8e)

Browse files
Files changed (2) hide show
  1. mae.py +3 -21
  2. requirements.txt +1 -1
mae.py CHANGED
@@ -13,9 +13,6 @@
13
  # limitations under the License.
14
  """MAE - Mean Absolute Error Metric"""
15
 
16
- from dataclasses import dataclass
17
- from typing import List, Optional, Union
18
-
19
  import datasets
20
  from sklearn.metrics import mean_absolute_error
21
 
@@ -84,26 +81,13 @@ Examples:
84
  """
85
 
86
 
87
- @dataclass
88
- class MaeConfig(evaluate.info.Config):
89
-
90
- name: str = "default"
91
-
92
- multioutput: str = "uniform_average"
93
- sample_weight: Optional[List[float]] = None
94
-
95
-
96
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
97
  class Mae(evaluate.Metric):
98
- CONFIG_CLASS = MaeConfig
99
- ALLOWED_CONFIG_NAMES = ["default", "multilist"]
100
-
101
- def _info(self, config):
102
  return evaluate.MetricInfo(
103
  description=_DESCRIPTION,
104
  citation=_CITATION,
105
  inputs_description=_KWARGS_DESCRIPTION,
106
- config=config,
107
  features=datasets.Features(self._get_feature_types()),
108
  reference_urls=[
109
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html"
@@ -122,10 +106,8 @@ class Mae(evaluate.Metric):
122
  "references": datasets.Value("float"),
123
  }
124
 
125
- def _compute(self, predictions, references):
126
 
127
- mae_score = mean_absolute_error(
128
- references, predictions, sample_weight=self.config.sample_weight, multioutput=self.config.multioutput
129
- )
130
 
131
  return {"mae": mae_score}
 
13
  # limitations under the License.
14
  """MAE - Mean Absolute Error Metric"""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import mean_absolute_error
18
 
 
81
  """
82
 
83
 
 
 
 
 
 
 
 
 
 
84
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85
  class Mae(evaluate.Metric):
86
+ def _info(self):
 
 
 
87
  return evaluate.MetricInfo(
88
  description=_DESCRIPTION,
89
  citation=_CITATION,
90
  inputs_description=_KWARGS_DESCRIPTION,
 
91
  features=datasets.Features(self._get_feature_types()),
92
  reference_urls=[
93
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_error.html"
 
106
  "references": datasets.Value("float"),
107
  }
108
 
109
+ def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average"):
110
 
111
+ mae_score = mean_absolute_error(references, predictions, sample_weight=sample_weight, multioutput=multioutput)
 
 
112
 
113
  return {"mae": mae_score}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
2
  sklearn