MilesCranmer commited on
Commit
00875eb
1 Parent(s): ba6e296

Hotfix for 1D weights

Browse files
Files changed (2) hide show
  1. pysr/sr.py +1 -1
  2. test/test.py +11 -0
pysr/sr.py CHANGED
@@ -1127,7 +1127,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1127
  if Xresampled is not None:
1128
  Xresampled = check_array(Xresampled)
1129
  if weights is not None:
1130
- weights = check_array(weights)
1131
  check_consistent_length(weights, y)
1132
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1133
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
 
1127
  if Xresampled is not None:
1128
  Xresampled = check_array(Xresampled)
1129
  if weights is not None:
1130
+ weights = check_array(weights, ensure_2d=False)
1131
  check_consistent_length(weights, y)
1132
  X, y = self._validate_data(X=X, y=y, reset=True, multi_output=True)
1133
  self.feature_names_in_ = _check_feature_names_in(self, variable_names)
test/test.py CHANGED
@@ -39,6 +39,17 @@ class TestPipeline(unittest.TestCase):
39
  print(model.equations_)
40
  self.assertLessEqual(model.get_best()["loss"], 1e-4)
41
 
 
 
 
 
 
 
 
 
 
 
 
42
  def test_multiprocessing(self):
43
  y = self.X[:, 0]
44
  model = PySRRegressor(
 
39
  print(model.equations_)
40
  self.assertLessEqual(model.get_best()["loss"], 1e-4)
41
 
42
+ def test_linear_relation_weighted(self):
43
+ y = self.X[:, 0]
44
+ weights = np.ones_like(y)
45
+ model = PySRRegressor(
46
+ **self.default_test_kwargs,
47
+ early_stop_condition="stop_if(loss, complexity) = loss < 1e-4 && complexity == 1",
48
+ )
49
+ model.fit(self.X, y, weights=weights)
50
+ print(model.equations_)
51
+ self.assertLessEqual(model.get_best()["loss"], 1e-4)
52
+
53
  def test_multiprocessing(self):
54
  y = self.X[:, 0]
55
  model = PySRRegressor(