MilesCranmer commited on
Commit
78cdb0e
1 Parent(s): 4ae8a5c

Add test for loading from pickle file

Browse files
Files changed (1) hide show
  1. test/test.py +27 -0
test/test.py CHANGED
@@ -309,6 +309,33 @@ class TestPipeline(unittest.TestCase):
309
 
310
  np.testing.assert_allclose(y_truth, y_test)
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  class TestBest(unittest.TestCase):
314
  def setUp(self):
 
309
 
310
  np.testing.assert_allclose(y_truth, y_test)
311
 
312
+ def test_load_model_simple(self):
313
+ # Test that we can simply load a model from its equation file.
314
+ y = self.X[:, [0, 1]] ** 2
315
+ model = PySRRegressor(
316
+ # Test that passing a single operator works:
317
+ unary_operators="sq(x) = x^2",
318
+ binary_operators="plus",
319
+ extra_sympy_mappings={"sq": lambda x: x**2},
320
+ **self.default_test_kwargs,
321
+ procs=0,
322
+ denoise=True,
323
+ early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
324
+ )
325
+ rand_dir = Path(tempfile.mkdtemp())
326
+ equation_file = rand_dir / "equations.csv"
327
+ model.set_params(temp_equation_file=False)
328
+ model.set_params(equation_file=equation_file)
329
+ model.fit(self.X, y)
330
+
331
+ # lambda functions are removed from the pickling, so we need
332
+ # to pass it during the loading:
333
+ model2 = load(
334
+ model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
335
+ )
336
+
337
+ np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
338
+
339
 
340
  class TestBest(unittest.TestCase):
341
  def setUp(self):