MilesCranmer commited on
Commit
b53e7fa
1 Parent(s): 6501ca0

Add additional test for loading from pickle file

Browse files
Files changed (2) hide show
  1. pysr/sr.py +6 -2
  2. test/test.py +10 -0
pysr/sr.py CHANGED
@@ -926,7 +926,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
926
 
927
  def _checkpoint(self):
928
  """Saves the model's current state to a checkpoint file.
929
-
930
  This should only be used internally by PySRRegressor."""
931
  # Save model state:
932
  self.show_pickle_warnings_ = False
@@ -2132,8 +2132,12 @@ def load(
2132
  assert n_features_in is None
2133
  with open(str(equation_file) + ".pkl", "rb") as f:
2134
  model = pkl.load(f)
 
 
2135
  model.set_params(**pysr_kwargs)
2136
- model.refresh()
 
 
2137
  return model
2138
 
2139
  # Else, we re-create it.
 
926
 
927
  def _checkpoint(self):
928
  """Saves the model's current state to a checkpoint file.
929
+
930
  This should only be used internally by PySRRegressor."""
931
  # Save model state:
932
  self.show_pickle_warnings_ = False
 
2132
  assert n_features_in is None
2133
  with open(str(equation_file) + ".pkl", "rb") as f:
2134
  model = pkl.load(f)
2135
+ # Update any parameters if necessary, such as
2136
+ # extra_sympy_mappings:
2137
  model.set_params(**pysr_kwargs)
2138
+ if "equations_" not in model.__dict__ or model.equations_ is None:
2139
+ model.refresh()
2140
+
2141
  return model
2142
 
2143
  # Else, we re-create it.
test/test.py CHANGED
@@ -336,6 +336,16 @@ class TestPipeline(unittest.TestCase):
336
 
337
  np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
338
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  class TestBest(unittest.TestCase):
341
  def setUp(self):
 
336
 
337
  np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
338
 
339
+ # Try again, but using only the pickle file:
340
+ for file_to_delete in [str(equation_file), str(equation_file) + ".bkup"]:
341
+ if os.path.exists(file_to_delete):
342
+ os.remove(file_to_delete)
343
+
344
+ model3 = load(
345
+ model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
346
+ )
347
+ np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
348
+
349
 
350
  class TestBest(unittest.TestCase):
351
  def setUp(self):