Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
78cdb0e
1
Parent(s):
4ae8a5c
Add test for loading from pickle file
Browse files- test/test.py +27 -0
test/test.py
CHANGED
@@ -309,6 +309,33 @@ class TestPipeline(unittest.TestCase):
|
|
309 |
|
310 |
np.testing.assert_allclose(y_truth, y_test)
|
311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
|
313 |
class TestBest(unittest.TestCase):
|
314 |
def setUp(self):
|
|
|
309 |
|
310 |
np.testing.assert_allclose(y_truth, y_test)
|
311 |
|
312 |
+
def test_load_model_simple(self):
|
313 |
+
# Test that we can simply load a model from its equation file.
|
314 |
+
y = self.X[:, [0, 1]] ** 2
|
315 |
+
model = PySRRegressor(
|
316 |
+
# Test that passing a single operator works:
|
317 |
+
unary_operators="sq(x) = x^2",
|
318 |
+
binary_operators="plus",
|
319 |
+
extra_sympy_mappings={"sq": lambda x: x**2},
|
320 |
+
**self.default_test_kwargs,
|
321 |
+
procs=0,
|
322 |
+
denoise=True,
|
323 |
+
early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
|
324 |
+
)
|
325 |
+
rand_dir = Path(tempfile.mkdtemp())
|
326 |
+
equation_file = rand_dir / "equations.csv"
|
327 |
+
model.set_params(temp_equation_file=False)
|
328 |
+
model.set_params(equation_file=equation_file)
|
329 |
+
model.fit(self.X, y)
|
330 |
+
|
331 |
+
# lambda functions are removed from the pickling, so we need
|
332 |
+
# to pass it during the loading:
|
333 |
+
model2 = load(
|
334 |
+
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
335 |
+
)
|
336 |
+
|
337 |
+
np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
|
338 |
+
|
339 |
|
340 |
class TestBest(unittest.TestCase):
|
341 |
def setUp(self):
|