MilesCranmer commited on
Commit
6501ca0
1 Parent(s): c6902b7

Checkpoint model before and after fit

Browse files
Files changed (1) hide show
  1. pysr/sr.py +22 -7
pysr/sr.py CHANGED
@@ -924,6 +924,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
924
  ]
925
  return pickled_state
926
 
 
 
 
 
 
 
 
 
 
 
927
  @property
928
  def equations(self): # pragma: no cover
929
  warnings.warn(
@@ -1624,13 +1634,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1624
  y,
1625
  )
1626
 
1627
- # Save model state:
1628
- self.show_pickle_warnings_ = False
1629
- with open(str(self.equation_file_) + ".pkl", "wb") as f:
1630
- pkl.dump(self, f)
1631
- self.show_pickle_warnings_ = True
1632
- # Fitting procedure
1633
- return self._run(X, y, mutated_params, weights=weights, seed=seed)
 
 
 
 
 
1634
 
1635
  def refresh(self, checkpoint_file=None):
1636
  """
 
924
  ]
925
  return pickled_state
926
 
927
+ def _checkpoint(self):
928
+ """Saves the model's current state to a checkpoint file.
929
+
930
+ This should only be used internally by PySRRegressor."""
931
+ # Save model state:
932
+ self.show_pickle_warnings_ = False
933
+ with open(str(self.equation_file_) + ".pkl", "wb") as f:
934
+ pkl.dump(self, f)
935
+ self.show_pickle_warnings_ = True
936
+
937
  @property
938
  def equations(self): # pragma: no cover
939
  warnings.warn(
 
1634
  y,
1635
  )
1636
 
1637
+ # Initially, just save model parameters, so that
1638
+ # it can be loaded from an early exit:
1639
+ self._checkpoint()
1640
+
1641
+ # Perform the search:
1642
+ self._run(X, y, mutated_params, weights=weights, seed=seed)
1643
+
1644
+ # Then, after fit, we save again, so the pickle file contains
1645
+ # the equations:
1646
+ self._checkpoint()
1647
+
1648
+ return self
1649
 
1650
  def refresh(self, checkpoint_file=None):
1651
  """