Spaces:
Sleeping
Sleeping
tttc3
commited on
Commit
•
c51257e
1
Parent(s):
4b56660
Fixed weight checking
Browse files- pysr/sr.py +4 -4
pysr/sr.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
import sys
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
-
from sklearn.utils import check_array, check_random_state
|
6 |
import sympy
|
7 |
from sympy import sympify
|
8 |
import re
|
@@ -15,7 +15,7 @@ from multiprocessing import cpu_count
|
|
15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
16 |
from sklearn.utils.validation import (
|
17 |
_check_feature_names_in,
|
18 |
-
|
19 |
check_is_fitted,
|
20 |
)
|
21 |
|
@@ -1073,7 +1073,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1073 |
if Xresampled is not None:
|
1074 |
Xresampled = check_array(Xresampled)
|
1075 |
if weights is not None:
|
1076 |
-
weights =
|
|
|
1077 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
1078 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
1079 |
variable_names = self.feature_names_in_
|
@@ -1461,7 +1462,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1461 |
|
1462 |
mutated_params = self._validate_init_params()
|
1463 |
|
1464 |
-
# Parameter input validation (for parameters defined in __init__)
|
1465 |
X, y, Xresampled, weights, variable_names = self._validate_fit_params(
|
1466 |
X, y, Xresampled, weights, variable_names
|
1467 |
)
|
|
|
2 |
import sys
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
+
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
6 |
import sympy
|
7 |
from sympy import sympify
|
8 |
import re
|
|
|
15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
16 |
from sklearn.utils.validation import (
|
17 |
_check_feature_names_in,
|
18 |
+
check_X_y,
|
19 |
check_is_fitted,
|
20 |
)
|
21 |
|
|
|
1073 |
if Xresampled is not None:
|
1074 |
Xresampled = check_array(Xresampled)
|
1075 |
if weights is not None:
|
1076 |
+
weights = check_array(weights)
|
1077 |
+
check_consistent_length(weights, y)
|
1078 |
X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
|
1079 |
self.feature_names_in_ = _check_feature_names_in(self, variable_names)
|
1080 |
variable_names = self.feature_names_in_
|
|
|
1462 |
|
1463 |
mutated_params = self._validate_init_params()
|
1464 |
|
|
|
1465 |
X, y, Xresampled, weights, variable_names = self._validate_fit_params(
|
1466 |
X, y, Xresampled, weights, variable_names
|
1467 |
)
|