MilesCranmer commited on
Commit
27fac96
·
1 Parent(s): b2f8a6f

Clean up testing code

Browse files
Files changed (1) hide show
  1. test/test.py +6 -12
test/test.py CHANGED
@@ -1,6 +1,7 @@
1
  import inspect
2
  import unittest
3
  import numpy as np
 
4
  from pysr import PySRRegressor
5
  from pysr.sr import run_feature_selection, _handle_feature_selection
6
  from sklearn.utils.estimator_checks import check_estimator
@@ -166,18 +167,15 @@ class TestPipeline(unittest.TestCase):
166
  unary_operators="sq(x) = x^2",
167
  binary_operators="plus",
168
  extra_sympy_mappings={"sq": lambda x: x**2},
169
- **{
170
- k: v
171
- for k, v in self.default_test_kwargs.items()
172
- if k != "model_selection"
173
- },
174
  procs=0,
175
  denoise=True,
176
  early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
177
- model_selection="best",
178
  )
 
 
 
179
  model.fit(self.X, y)
180
- print(model)
181
  self.assertLessEqual(model.get_best()[1]["loss"], 1e-2)
182
  self.assertLessEqual(model.get_best()[1]["loss"], 1e-2)
183
 
@@ -326,10 +324,6 @@ class TestFeatureSelection(unittest.TestCase):
326
  class TestMiscellaneous(unittest.TestCase):
327
  """Test miscellaneous functions."""
328
 
329
- def setUp(self):
330
- # Allows all scikit-learn exception messages to be read.
331
- self.maxDiff = None
332
-
333
  def test_deprecation(self):
334
  """Ensure that deprecation works as expected.
335
 
@@ -344,7 +338,7 @@ class TestMiscellaneous(unittest.TestCase):
344
 
345
  def test_size_warning(self):
346
  """Ensure that a warning is given for a large input size."""
347
- model = PySRRegressor(max_evals=10000, populations=2)
348
  X = np.random.randn(10001, 2)
349
  y = np.random.randn(10001)
350
  with warnings.catch_warnings():
 
1
  import inspect
2
  import unittest
3
  import numpy as np
4
+ from sklearn import model_selection
5
  from pysr import PySRRegressor
6
  from pysr.sr import run_feature_selection, _handle_feature_selection
7
  from sklearn.utils.estimator_checks import check_estimator
 
167
  unary_operators="sq(x) = x^2",
168
  binary_operators="plus",
169
  extra_sympy_mappings={"sq": lambda x: x**2},
170
+ **self.default_test_kwargs,
 
 
 
 
171
  procs=0,
172
  denoise=True,
173
  early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
 
174
  )
175
+ # We expect in this case that the "best"
176
+ # equation should be the right one:
177
+ model.set_params(model_selection="best")
178
  model.fit(self.X, y)
 
179
  self.assertLessEqual(model.get_best()[1]["loss"], 1e-2)
180
  self.assertLessEqual(model.get_best()[1]["loss"], 1e-2)
181
 
 
324
  class TestMiscellaneous(unittest.TestCase):
325
  """Test miscellaneous functions."""
326
 
 
 
 
 
327
  def test_deprecation(self):
328
  """Ensure that deprecation works as expected.
329
 
 
338
 
339
  def test_size_warning(self):
340
  """Ensure that a warning is given for a large input size."""
341
+ model = PySRRegressor()
342
  X = np.random.randn(10001, 2)
343
  y = np.random.randn(10001)
344
  with warnings.catch_warnings():