lvwerra HF staff commited on
Commit
2cd389f
·
1 Parent(s): 67f1b04

Update Space (evaluate main: e4a27243)

Browse files
Files changed (2) hide show
  1. mse.py +25 -3
  2. requirements.txt +1 -1
mse.py CHANGED
@@ -13,6 +13,9 @@
13
  # limitations under the License.
14
  """MSE - Mean Squared Error Metric"""
15
 
 
 
 
16
  import datasets
17
  from sklearn.metrics import mean_squared_error
18
 
@@ -85,13 +88,28 @@ Examples:
85
  """
86
 
87
 
 
 
 
 
 
 
 
 
 
 
88
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
89
  class Mse(evaluate.Metric):
90
- def _info(self):
 
 
 
 
91
  return evaluate.MetricInfo(
92
  description=_DESCRIPTION,
93
  citation=_CITATION,
94
  inputs_description=_KWARGS_DESCRIPTION,
 
95
  features=datasets.Features(self._get_feature_types()),
96
  reference_urls=[
97
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"
@@ -110,10 +128,14 @@ class Mse(evaluate.Metric):
110
  "references": datasets.Value("float"),
111
  }
112
 
113
- def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
  mse = mean_squared_error(
116
- references, predictions, sample_weight=sample_weight, multioutput=multioutput, squared=squared
 
 
 
 
117
  )
118
 
119
  return {"mse": mse}
 
13
  # limitations under the License.
14
  """MSE - Mean Squared Error Metric"""
15
 
16
+ from dataclasses import dataclass
17
+ from typing import List, Optional
18
+
19
  import datasets
20
  from sklearn.metrics import mean_squared_error
21
 
 
88
  """
89
 
90
 
91
+ @dataclass
92
+ class MseConfig(evaluate.info.Config):
93
+
94
+ name: str = "default"
95
+
96
+ multioutput: str = "uniform_average"
97
+ sample_weight: Optional[List[float]] = None
98
+ squared: bool = True
99
+
100
+
101
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
102
  class Mse(evaluate.Metric):
103
+
104
+ CONFIG_CLASS = MseConfig
105
+ ALLOWED_CONFIG_NAMES = ["default", "multilist"]
106
+
107
+ def _info(self, config):
108
  return evaluate.MetricInfo(
109
  description=_DESCRIPTION,
110
  citation=_CITATION,
111
  inputs_description=_KWARGS_DESCRIPTION,
112
+ config=config,
113
  features=datasets.Features(self._get_feature_types()),
114
  reference_urls=[
115
  "https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_squared_error.html"
 
128
  "references": datasets.Value("float"),
129
  }
130
 
131
+ def _compute(self, predictions, references):
132
 
133
  mse = mean_squared_error(
134
+ references,
135
+ predictions,
136
+ sample_weight=self.config.sample_weight,
137
+ multioutput=self.config.multioutput,
138
+ squared=self.config.squared,
139
  )
140
 
141
  return {"mse": mse}
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
2
  sklearn
 
1
+ git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  sklearn