tttc3 commited on
Commit
c51257e
1 Parent(s): 4b56660

Fixed weight checking

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -4
pysr/sr.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import sys
3
  import numpy as np
4
  import pandas as pd
5
- from sklearn.utils import check_array, check_random_state
6
  import sympy
7
  from sympy import sympify
8
  import re
@@ -15,7 +15,7 @@ from multiprocessing import cpu_count
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
16
  from sklearn.utils.validation import (
17
  _check_feature_names_in,
18
- _check_sample_weight,
19
  check_is_fitted,
20
  )
21
 
@@ -1073,7 +1073,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1073
  if Xresampled is not None:
1074
  Xresampled = check_array(Xresampled)
1075
  if weights is not None:
1076
- weights = _check_sample_weight(weights, y)
 
1077
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1078
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1079
  variable_names = self.feature_names_in_
@@ -1461,7 +1462,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1461
 
1462
  mutated_params = self._validate_init_params()
1463
 
1464
- # Parameter input validation (for parameters defined in __init__)
1465
  X, y, Xresampled, weights, variable_names = self._validate_fit_params(
1466
  X, y, Xresampled, weights, variable_names
1467
  )
 
2
  import sys
3
  import numpy as np
4
  import pandas as pd
5
+ from sklearn.utils import check_array, check_consistent_length, check_random_state
6
  import sympy
7
  from sympy import sympify
8
  import re
 
15
  from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
16
  from sklearn.utils.validation import (
17
  _check_feature_names_in,
18
+ check_X_y,
19
  check_is_fitted,
20
  )
21
 
 
1073
  if Xresampled is not None:
1074
  Xresampled = check_array(Xresampled)
1075
  if weights is not None:
1076
+ weights = check_array(weights)
1077
+ check_consistent_length(weights, y)
1078
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1079
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
1080
  variable_names = self.feature_names_in_
 
1462
 
1463
  mutated_params = self._validate_init_params()
1464
 
 
1465
  X, y, Xresampled, weights, variable_names = self._validate_fit_params(
1466
  X, y, Xresampled, weights, variable_names
1467
  )