MilesCranmer commited on
Commit
e274713
1 Parent(s): ec8124e

Fix test for PySRRegressor

Browse files
Files changed (2) hide show
  1. pysr/sr.py +1 -1
  2. test/test.py +7 -16
pysr/sr.py CHANGED
@@ -914,7 +914,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
914
  y = np.stack(
915
  [
916
  _denoise(X, y[:, i], Xresampled=Xresampled)[1]
917
- for i in range(nout)
918
  ],
919
  axis=1,
920
  )
 
914
  y = np.stack(
915
  [
916
  _denoise(X, y[:, i], Xresampled=Xresampled)[1]
917
+ for i in range(self.nout)
918
  ],
919
  axis=1,
920
  )
test/test.py CHANGED
@@ -91,7 +91,7 @@ class TestPipeline(unittest.TestCase):
91
  self.assertTrue("None" not in regressor.__repr__())
92
  self.assertTrue(">>>>" in regressor.__repr__())
93
 
94
- self.assertLessEqual(regressor.equations.iloc[-1]["MSE"], 1e-4)
95
  np.testing.assert_almost_equal(regressor.predict(X), y, decimal=1)
96
 
97
  # Tweak model selection:
@@ -181,18 +181,17 @@ class TestBest(unittest.TestCase):
181
  "equation_file.csv.bkup", sep="|"
182
  )
183
 
184
- self.equations = get_hof(
185
- "equation_file.csv",
186
- n_features=2,
187
  variables_names="x0 x1".split(" "),
188
  extra_sympy_mappings={},
189
  output_jax_format=False,
190
  multioutput=False,
191
  nout=1,
192
  )
193
-
194
- self.model = PySRRegressor()
195
- self.model.equations = self.equations
196
 
197
  def test_best(self):
198
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
@@ -232,12 +231,4 @@ class TestFeatureSelection(unittest.TestCase):
232
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
233
  np.testing.assert_array_equal(
234
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
235
- )
236
-
237
-
238
- class TestHelperFunctions(unittest.TestCase):
239
- @patch("builtins.input", side_effect=["y", "n"])
240
- def test_yesno(self, mock_input):
241
- # Assert that the yes/no function correctly deals with y/n
242
- self.assertEqual(_yesno("Test"), True)
243
- self.assertEqual(_yesno("Test"), False)
 
91
  self.assertTrue("None" not in regressor.__repr__())
92
  self.assertTrue(">>>>" in regressor.__repr__())
93
 
94
+ self.assertLessEqual(regressor.equations.iloc[-1]["loss"], 1e-4)
95
  np.testing.assert_almost_equal(regressor.predict(X), y, decimal=1)
96
 
97
  # Tweak model selection:
 
181
  "equation_file.csv.bkup", sep="|"
182
  )
183
 
184
+ self.model = PySRRegressor(
185
+ equation_file="equation_file.csv",
 
186
  variables_names="x0 x1".split(" "),
187
  extra_sympy_mappings={},
188
  output_jax_format=False,
189
  multioutput=False,
190
  nout=1,
191
  )
192
+ self.model.n_features = 2
193
+ self.model.refresh()
194
+ self.equations = self.model.equations
195
 
196
  def test_best(self):
197
  self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
 
231
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
232
  np.testing.assert_array_equal(
233
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
234
+ )