MilesCranmer commited on
Commit
5c0ad55
1 Parent(s): b16d9ef

Allow loading from pickle file

Browse files
Files changed (1) hide show
  1. pysr/sr.py +24 -3
pysr/sr.py CHANGED
@@ -2061,9 +2061,9 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
2061
  def load(
2062
  equation_file,
2063
  *,
2064
- binary_operators,
2065
- unary_operators,
2066
- n_features_in,
2067
  feature_names_in=None,
2068
  selection_mask=None,
2069
  nout=1,
@@ -2097,12 +2097,33 @@ def load(
2097
 
2098
  pysr_kwargs : dict
2099
  Any other keyword arguments to initialize the PySRRegressor object.
 
2100
 
2101
  Returns
2102
  -------
2103
  model : PySRRegressor
2104
  The model with fitted equations.
2105
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2106
 
2107
  # TODO: copy .bkup file if exists.
2108
  model = PySRRegressor(
 
2061
  def load(
2062
  equation_file,
2063
  *,
2064
+ binary_operators=None,
2065
+ unary_operators=None,
2066
+ n_features_in=None,
2067
  feature_names_in=None,
2068
  selection_mask=None,
2069
  nout=1,
 
2097
 
2098
  pysr_kwargs : dict
2099
  Any other keyword arguments to initialize the PySRRegressor object.
2100
+ These will overwrite those stored in the pickle file.
2101
 
2102
  Returns
2103
  -------
2104
  model : PySRRegressor
2105
  The model with fitted equations.
2106
  """
2107
+ # Try to load model from <equation_file>.pkl
2108
+ print(f"Checking if {equation_file}.pkl exists...")
2109
+ if os.path.exists(str(equation_file) + ".pkl"):
2110
+ assert binary_operators is None
2111
+ assert unary_operators is None
2112
+ assert n_features_in is None
2113
+ with open(str(equation_file) + ".pkl", "rb") as f:
2114
+ model = pkl.load(f)
2115
+ model.set_params(**pysr_kwargs)
2116
+ model.refresh()
2117
+ return model
2118
+
2119
+ # Else, we re-create it.
2120
+ print(
2121
+ f"{equation_file}.pkl does not exist, "
2122
+ "so we must create the model from scratch."
2123
+ )
2124
+ assert binary_operators is not None
2125
+ assert unary_operators is not None
2126
+ assert n_features_in is not None
2127
 
2128
  # TODO: copy .bkup file if exists.
2129
  model = PySRRegressor(