MilesCranmer commited on
Commit
70dcb83
1 Parent(s): aa16a1e

Add reset function for state saving.

Browse files
Files changed (1) hide show
  1. pysr/sr.py +11 -6
pysr/sr.py CHANGED
@@ -322,7 +322,7 @@ def _write_project_file(tmp_dir):
322
  SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
323
 
324
  [compat]
325
- SymbolicRegression = "0.7.1"
326
  julia = "1.5"
327
  """
328
 
@@ -640,7 +640,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
640
  self.equations = None
641
  self.params_hash = None
642
  self.raw_julia_state = None
643
- self.raw_julia_hof = None
644
 
645
  self.multioutput = None
646
  self.equation_file = equation_file
@@ -861,6 +860,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
861
  return [eq["torch_format"] for eq in best]
862
  return best["torch_format"]
863
 
 
 
 
 
 
 
864
  def _run(self, X, y, weights, variable_names):
865
  global already_ran
866
  global Main
@@ -1074,7 +1079,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1074
  "Warning: PySR options have changed since the last run. "
1075
  "This is experimental and may not work. "
1076
  "For example, if the operators change, or even their order,"
1077
- " the saved equations will be in the wrong format.",
 
 
1078
  )
1079
 
1080
  self.params_hash = cur_hash
@@ -1140,9 +1147,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1140
 
1141
  cprocs = 0 if multithreading else procs
1142
 
1143
- # Julia return value:
1144
- # state = (returnPops, hallOfFame)
1145
- self.raw_julia_state, self.raw_julia_hof = Main.EquationSearch(
1146
  Main.X,
1147
  Main.y,
1148
  weights=Main.weights,
 
322
  SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
323
 
324
  [compat]
325
+ SymbolicRegression = "0.7.2"
326
  julia = "1.5"
327
  """
328
 
 
640
  self.equations = None
641
  self.params_hash = None
642
  self.raw_julia_state = None
 
643
 
644
  self.multioutput = None
645
  self.equation_file = equation_file
 
860
  return [eq["torch_format"] for eq in best]
861
  return best["torch_format"]
862
 
863
+ def reset(self):
864
+ """Reset the search state."""
865
+ self.equations = None
866
+ self.params_hash = None
867
+ self.raw_julia_state = None
868
+
869
  def _run(self, X, y, weights, variable_names):
870
  global already_ran
871
  global Main
 
1079
  "Warning: PySR options have changed since the last run. "
1080
  "This is experimental and may not work. "
1081
  "For example, if the operators change, or even their order,"
1082
+ " the saved equations will be in the wrong format."
1083
+ "\n\n"
1084
+ "To reset the search state, run `.reset()`. "
1085
  )
1086
 
1087
  self.params_hash = cur_hash
 
1147
 
1148
  cprocs = 0 if multithreading else procs
1149
 
1150
+ self.raw_julia_state = Main.EquationSearch(
 
 
1151
  Main.X,
1152
  Main.y,
1153
  weights=Main.weights,