MilesCranmer commited on
Commit
ed19905
1 Parent(s): bae75db

Start on state saving

Browse files
Files changed (1) hide show
  1. pysr/sr.py +23 -3
pysr/sr.py CHANGED
@@ -636,9 +636,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
636
 
637
  # Stored equations:
638
  self.equations = None
 
 
 
639
 
640
  self.multioutput = None
641
- self.raw_julia_output = None
642
  self.equation_file = equation_file
643
  self.n_features = None
644
  self.extra_sympy_mappings = extra_sympy_mappings
@@ -654,7 +656,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
654
  self.surface_parameters = [
655
  "model_selection",
656
  "multioutput",
657
- "raw_julia_output",
658
  "equation_file",
659
  "n_features",
660
  "extra_sympy_mappings",
@@ -1046,6 +1047,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1046
  float(weightDoNothing),
1047
  ]
1048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1049
  options = Main.Options(
1050
  binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
1051
  unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
@@ -1085,6 +1101,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1085
  optimizer_iterations=self.params["optimizer_iterations"],
1086
  perturbationFactor=self.params["perturbationFactor"],
1087
  annealing=self.params["annealing"],
 
1088
  )
1089
 
1090
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
@@ -1106,7 +1123,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1106
 
1107
  cprocs = 0 if multithreading else procs
1108
 
1109
- self.raw_julia_output = Main.EquationSearch(
 
 
1110
  Main.X,
1111
  Main.y,
1112
  weights=Main.weights,
@@ -1119,6 +1138,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1119
  options=options,
1120
  numprocs=int(cprocs),
1121
  multithreading=bool(multithreading),
 
1122
  )
1123
 
1124
  self.variable_names = variable_names
 
636
 
637
  # Stored equations:
638
  self.equations = None
639
+ self.params_hash = None
640
+ self.raw_julia_state = None
641
+ self.raw_julia_hof = None
642
 
643
  self.multioutput = None
 
644
  self.equation_file = equation_file
645
  self.n_features = None
646
  self.extra_sympy_mappings = extra_sympy_mappings
 
656
  self.surface_parameters = [
657
  "model_selection",
658
  "multioutput",
 
659
  "equation_file",
660
  "n_features",
661
  "extra_sympy_mappings",
 
1047
  float(weightDoNothing),
1048
  ]
1049
 
1050
+ all_params = {
1051
+ **{k: self.__getattribute__(k) for k in self.surface_parameters}
1052
+ ** self.params
1053
+ }
1054
+ if self.params_hash is not None:
1055
+ if hash(all_params) != self.params_hash:
1056
+ warnings.warn(
1057
+ "Warning: PySR options have changed since the last run. "
1058
+ "This is experimental and may not work. "
1059
+ "For example, if the operators change, or even their order,",
1060
+ " the saved equations will be in the wrong format."
1061
+ )
1062
+
1063
+ self.params_hash = hash(all_params)
1064
+
1065
  options = Main.Options(
1066
  binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
1067
  unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
 
1101
  optimizer_iterations=self.params["optimizer_iterations"],
1102
  perturbationFactor=self.params["perturbationFactor"],
1103
  annealing=self.params["annealing"],
1104
+ stateReturn=True, # Required for state saving.
1105
  )
1106
 
1107
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
 
1123
 
1124
  cprocs = 0 if multithreading else procs
1125
 
1126
+ # Julia return value:
1127
+ # state = (returnPops, hallOfFame)
1128
+ self.raw_julia_state, self.raw_julia_hof = Main.EquationSearch(
1129
  Main.X,
1130
  Main.y,
1131
  weights=Main.weights,
 
1138
  options=options,
1139
  numprocs=int(cprocs),
1140
  multithreading=bool(multithreading),
1141
+ saved_state=self.raw_julia_state,
1142
  )
1143
 
1144
  self.variable_names = variable_names