MilesCranmer commited on
Commit
34f4e3f
1 Parent(s): f5577ea

Change model load to classmethod

Browse files
Files changed (3) hide show
  1. pysr/__init__.py +0 -1
  2. pysr/sr.py +113 -104
  3. test/test.py +4 -4
pysr/__init__.py CHANGED
@@ -6,7 +6,6 @@ from .sr import (
6
  best_tex,
7
  best_callable,
8
  best_row,
9
- load,
10
  )
11
  from .julia_helpers import install
12
  from .feynman_problems import Problem, FeynmanProblem
 
6
  best_tex,
7
  best_callable,
8
  best_row,
 
9
  )
10
  from .julia_helpers import install
11
  from .feynman_problems import Problem, FeynmanProblem
pysr/sr.py CHANGED
@@ -810,6 +810,119 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
810
  f"{k} is not a valid keyword argument for PySRRegressor."
811
  )
812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  def __repr__(self):
814
  """
815
  Prints all current equations fitted by the model.
@@ -2091,107 +2204,3 @@ def _csv_filename_to_pkl_filename(csv_filename) -> str:
2091
  pkl_basename = base + ".pkl"
2092
 
2093
  return os.path.join(dirname, pkl_basename)
2094
-
2095
-
2096
- def load(
2097
- equation_file,
2098
- *,
2099
- binary_operators=None,
2100
- unary_operators=None,
2101
- n_features_in=None,
2102
- feature_names_in=None,
2103
- selection_mask=None,
2104
- nout=1,
2105
- **pysr_kwargs,
2106
- ):
2107
- """
2108
- Create a model from equations stored as a csv file
2109
-
2110
- Parameters
2111
- ----------
2112
- equation_file : str
2113
- Path to a csv file containing equations, or a pickle file
2114
- containing the model.
2115
-
2116
- binary_operators : list[str], default=["+", "-", "*", "/"]
2117
- The same binary operators used when creating the model.
2118
-
2119
- unary_operators : list[str], default=None
2120
- The same unary operators used when creating the model.
2121
-
2122
- n_features_in : int
2123
- Number of features passed to the model.
2124
-
2125
- feature_names_in : list[str], default=None
2126
- Names of the features passed to the model.
2127
-
2128
- selection_mask : list[bool], default=None
2129
- If using select_k_features, you must pass `model.selection_mask_` here.
2130
-
2131
- nout : int, default=1
2132
- Number of outputs of the model.
2133
-
2134
- pysr_kwargs : dict
2135
- Any other keyword arguments to initialize the PySRRegressor object.
2136
- These will overwrite those stored in the pickle file.
2137
-
2138
- Returns
2139
- -------
2140
- model : PySRRegressor
2141
- The model with fitted equations.
2142
- """
2143
- if os.path.splitext(equation_file)[1] != ".pkl":
2144
- pkl_filename = _csv_filename_to_pkl_filename(equation_file)
2145
- else:
2146
- pkl_filename = equation_file
2147
-
2148
- # Try to load model from <equation_file>.pkl
2149
- print(f"Checking if {pkl_filename} exists...")
2150
- if os.path.exists(pkl_filename):
2151
- print(f"Loading model from {pkl_filename}")
2152
- assert binary_operators is None
2153
- assert unary_operators is None
2154
- assert n_features_in is None
2155
- with open(pkl_filename, "rb") as f:
2156
- model = pkl.load(f)
2157
- # Update any parameters if necessary, such as
2158
- # extra_sympy_mappings:
2159
- model.set_params(**pysr_kwargs)
2160
- if "equations_" not in model.__dict__ or model.equations_ is None:
2161
- model.refresh()
2162
-
2163
- return model
2164
-
2165
- # Else, we re-create it.
2166
- print(
2167
- f"{equation_file} does not exist, " "so we must create the model from scratch."
2168
- )
2169
- assert binary_operators is not None
2170
- assert unary_operators is not None
2171
- assert n_features_in is not None
2172
-
2173
- # TODO: copy .bkup file if exists.
2174
- model = PySRRegressor(
2175
- equation_file=equation_file,
2176
- binary_operators=binary_operators,
2177
- unary_operators=unary_operators,
2178
- **pysr_kwargs,
2179
- )
2180
-
2181
- model.nout_ = nout
2182
- model.n_features_in_ = n_features_in
2183
-
2184
- if feature_names_in is None:
2185
- model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
2186
- else:
2187
- assert len(feature_names_in) == n_features_in
2188
- model.feature_names_in_ = feature_names_in
2189
-
2190
- if selection_mask is None:
2191
- model.selection_mask_ = np.ones(n_features_in, dtype=bool)
2192
- else:
2193
- model.selection_mask_ = selection_mask
2194
-
2195
- model.refresh(checkpoint_file=equation_file)
2196
-
2197
- return model
 
810
  f"{k} is not a valid keyword argument for PySRRegressor."
811
  )
812
 
813
+ @classmethod
814
+ def from_file(
815
+ cls,
816
+ equation_file,
817
+ *,
818
+ binary_operators=None,
819
+ unary_operators=None,
820
+ n_features_in=None,
821
+ feature_names_in=None,
822
+ selection_mask=None,
823
+ nout=1,
824
+ **pysr_kwargs,
825
+ ):
826
+ """
827
+ Create a model from a saved model checkpoint or equation file.
828
+
829
+ Parameters
830
+ ----------
831
+ equation_file : str
832
+ Path to a pickle file containing a saved model, or a csv file
833
+ containing equations.
834
+
835
+ binary_operators : list[str]
836
+ The same binary operators used when creating the model.
837
+ Not needed if loading from a pickle file.
838
+
839
+ unary_operators : list[str]
840
+ The same unary operators used when creating the model.
841
+ Not needed if loading from a pickle file.
842
+
843
+ n_features_in : int
844
+ Number of features passed to the model.
845
+ Not needed if loading from a pickle file.
846
+
847
+ feature_names_in : list[str]
848
+ Names of the features passed to the model.
849
+ Not needed if loading from a pickle file.
850
+
851
+ selection_mask : list[bool]
852
+ If using select_k_features, you must pass `model.selection_mask_` here.
853
+ Not needed if loading from a pickle file.
854
+
855
+ nout : int, default=1
856
+ Number of outputs of the model.
857
+ Not needed if loading from a pickle file.
858
+
859
+ pysr_kwargs : dict
860
+ Any other keyword arguments to initialize the PySRRegressor object.
861
+ These will overwrite those stored in the pickle file.
862
+ Not needed if loading from a pickle file.
863
+
864
+ Returns
865
+ -------
866
+ model : PySRRegressor
867
+ The model with fitted equations.
868
+ """
869
+ if os.path.splitext(equation_file)[1] != ".pkl":
870
+ pkl_filename = _csv_filename_to_pkl_filename(equation_file)
871
+ else:
872
+ pkl_filename = equation_file
873
+
874
+ # Try to load model from <equation_file>.pkl
875
+ print(f"Checking if {pkl_filename} exists...")
876
+ if os.path.exists(pkl_filename):
877
+ print(f"Loading model from {pkl_filename}")
878
+ assert binary_operators is None
879
+ assert unary_operators is None
880
+ assert n_features_in is None
881
+ with open(pkl_filename, "rb") as f:
882
+ model = pkl.load(f)
883
+ # Update any parameters if necessary, such as
884
+ # extra_sympy_mappings:
885
+ model.set_params(**pysr_kwargs)
886
+ if "equations_" not in model.__dict__ or model.equations_ is None:
887
+ model.refresh()
888
+
889
+ return model
890
+
891
+ # Else, we re-create it.
892
+ print(
893
+ f"{equation_file} does not exist, "
894
+ "so we must create the model from scratch."
895
+ )
896
+ assert binary_operators is not None
897
+ assert unary_operators is not None
898
+ assert n_features_in is not None
899
+
900
+ # TODO: copy .bkup file if exists.
901
+ model = cls(
902
+ equation_file=equation_file,
903
+ binary_operators=binary_operators,
904
+ unary_operators=unary_operators,
905
+ **pysr_kwargs,
906
+ )
907
+
908
+ model.nout_ = nout
909
+ model.n_features_in_ = n_features_in
910
+
911
+ if feature_names_in is None:
912
+ model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
913
+ else:
914
+ assert len(feature_names_in) == n_features_in
915
+ model.feature_names_in_ = feature_names_in
916
+
917
+ if selection_mask is None:
918
+ model.selection_mask_ = np.ones(n_features_in, dtype=bool)
919
+ else:
920
+ model.selection_mask_ = selection_mask
921
+
922
+ model.refresh(checkpoint_file=equation_file)
923
+
924
+ return model
925
+
926
  def __repr__(self):
927
  """
928
  Prints all current equations fitted by the model.
 
2204
  pkl_basename = base + ".pkl"
2205
 
2206
  return os.path.join(dirname, pkl_basename)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test/test.py CHANGED
@@ -4,7 +4,7 @@ import inspect
4
  import unittest
5
  import numpy as np
6
  from sklearn import model_selection
7
- from pysr import PySRRegressor, load
8
  from pysr.sr import (
9
  run_feature_selection,
10
  _handle_feature_selection,
@@ -300,7 +300,7 @@ class TestPipeline(unittest.TestCase):
300
  equation_filename = str(rand_dir / "equation.csv")
301
  with open(equation_filename + (".bkup" if from_backup else ""), "w") as f:
302
  f.write(csv_file_data)
303
- model = load(
304
  equation_filename,
305
  n_features_in=5,
306
  feature_names_in=["f0", "f1", "f2", "f3", "f4"],
@@ -334,7 +334,7 @@ class TestPipeline(unittest.TestCase):
334
 
335
  # lambda functions are removed from the pickling, so we need
336
  # to pass it during the loading:
337
- model2 = load(
338
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
339
  )
340
 
@@ -346,7 +346,7 @@ class TestPipeline(unittest.TestCase):
346
  os.remove(file_to_delete)
347
 
348
  pickle_file = rand_dir / "equations.pkl"
349
- model3 = load(
350
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
351
  )
352
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
 
4
  import unittest
5
  import numpy as np
6
  from sklearn import model_selection
7
+ from pysr import PySRRegressor
8
  from pysr.sr import (
9
  run_feature_selection,
10
  _handle_feature_selection,
 
300
  equation_filename = str(rand_dir / "equation.csv")
301
  with open(equation_filename + (".bkup" if from_backup else ""), "w") as f:
302
  f.write(csv_file_data)
303
+ model = PySRRegressor.from_file(
304
  equation_filename,
305
  n_features_in=5,
306
  feature_names_in=["f0", "f1", "f2", "f3", "f4"],
 
334
 
335
  # lambda functions are removed from the pickling, so we need
336
  # to pass it during the loading:
337
+ model2 = PySRRegressor.from_file(
338
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
339
  )
340
 
 
346
  os.remove(file_to_delete)
347
 
348
  pickle_file = rand_dir / "equations.pkl"
349
+ model3 = PySRRegressor.from_file(
350
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
351
  )
352
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))