Natooz commited on
Commit
3945a39
·
verified ·
1 Parent(s): 4f846cc

fixing call to `mean_squared_error` to comply with the newer sklearn versions

Browse files
Files changed (1) hide show
  1. mse.py +9 -4
mse.py CHANGED
@@ -14,7 +14,7 @@
14
  """MSE - Mean Squared Error Metric"""
15
 
16
  import datasets
17
- from sklearn.metrics import mean_squared_error
18
 
19
  import evaluate
20
 
@@ -112,8 +112,13 @@ class Mse(evaluate.Metric):
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
- mse = mean_squared_error(
116
- references, predictions, sample_weight=sample_weight, multioutput=multioutput, squared=squared
117
- )
 
 
 
 
 
118
 
119
  return {"mse": mse}
 
14
  """MSE - Mean Squared Error Metric"""
15
 
16
  import datasets
17
+ from sklearn.metrics import mean_squared_error, root_mean_squared_error
18
 
19
  import evaluate
20
 
 
112
 
113
  def _compute(self, predictions, references, sample_weight=None, multioutput="uniform_average", squared=True):
114
 
115
+ if squared:
116
+ mse = mean_squared_error(
117
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
118
+ )
119
+ else:
120
+ mse = root_mean_squared_error(
121
+ references, predictions, sample_weight=sample_weight, multioutput=multioutput
122
+ )
123
 
124
  return {"mse": mse}