Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
ed19905
1
Parent(s):
bae75db
Start on state saving
Browse files- pysr/sr.py +23 -3
pysr/sr.py
CHANGED
@@ -636,9 +636,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
636 |
|
637 |
# Stored equations:
|
638 |
self.equations = None
|
|
|
|
|
|
|
639 |
|
640 |
self.multioutput = None
|
641 |
-
self.raw_julia_output = None
|
642 |
self.equation_file = equation_file
|
643 |
self.n_features = None
|
644 |
self.extra_sympy_mappings = extra_sympy_mappings
|
@@ -654,7 +656,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
654 |
self.surface_parameters = [
|
655 |
"model_selection",
|
656 |
"multioutput",
|
657 |
-
"raw_julia_output",
|
658 |
"equation_file",
|
659 |
"n_features",
|
660 |
"extra_sympy_mappings",
|
@@ -1046,6 +1047,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1046 |
float(weightDoNothing),
|
1047 |
]
|
1048 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1049 |
options = Main.Options(
|
1050 |
binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
|
1051 |
unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
|
@@ -1085,6 +1101,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1085 |
optimizer_iterations=self.params["optimizer_iterations"],
|
1086 |
perturbationFactor=self.params["perturbationFactor"],
|
1087 |
annealing=self.params["annealing"],
|
|
|
1088 |
)
|
1089 |
|
1090 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
@@ -1106,7 +1123,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1106 |
|
1107 |
cprocs = 0 if multithreading else procs
|
1108 |
|
1109 |
-
|
|
|
|
|
1110 |
Main.X,
|
1111 |
Main.y,
|
1112 |
weights=Main.weights,
|
@@ -1119,6 +1138,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1119 |
options=options,
|
1120 |
numprocs=int(cprocs),
|
1121 |
multithreading=bool(multithreading),
|
|
|
1122 |
)
|
1123 |
|
1124 |
self.variable_names = variable_names
|
|
|
636 |
|
637 |
# Stored equations:
|
638 |
self.equations = None
|
639 |
+
self.params_hash = None
|
640 |
+
self.raw_julia_state = None
|
641 |
+
self.raw_julia_hof = None
|
642 |
|
643 |
self.multioutput = None
|
|
|
644 |
self.equation_file = equation_file
|
645 |
self.n_features = None
|
646 |
self.extra_sympy_mappings = extra_sympy_mappings
|
|
|
656 |
self.surface_parameters = [
|
657 |
"model_selection",
|
658 |
"multioutput",
|
|
|
659 |
"equation_file",
|
660 |
"n_features",
|
661 |
"extra_sympy_mappings",
|
|
|
1047 |
float(weightDoNothing),
|
1048 |
]
|
1049 |
|
1050 |
+
all_params = {
|
1051 |
+
**{k: self.__getattribute__(k) for k in self.surface_parameters}
|
1052 |
+
** self.params
|
1053 |
+
}
|
1054 |
+
if self.params_hash is not None:
|
1055 |
+
if hash(all_params) != self.params_hash:
|
1056 |
+
warnings.warn(
|
1057 |
+
"Warning: PySR options have changed since the last run. "
|
1058 |
+
"This is experimental and may not work. "
|
1059 |
+
"For example, if the operators change, or even their order,",
|
1060 |
+
" the saved equations will be in the wrong format."
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
self.params_hash = hash(all_params)
|
1064 |
+
|
1065 |
options = Main.Options(
|
1066 |
binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
|
1067 |
unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
|
|
|
1101 |
optimizer_iterations=self.params["optimizer_iterations"],
|
1102 |
perturbationFactor=self.params["perturbationFactor"],
|
1103 |
annealing=self.params["annealing"],
|
1104 |
+
stateReturn=True, # Required for state saving.
|
1105 |
)
|
1106 |
|
1107 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
|
|
1123 |
|
1124 |
cprocs = 0 if multithreading else procs
|
1125 |
|
1126 |
+
# Julia return value:
|
1127 |
+
# state = (returnPops, hallOfFame)
|
1128 |
+
self.raw_julia_state, self.raw_julia_hof = Main.EquationSearch(
|
1129 |
Main.X,
|
1130 |
Main.y,
|
1131 |
weights=Main.weights,
|
|
|
1138 |
options=options,
|
1139 |
numprocs=int(cprocs),
|
1140 |
multithreading=bool(multithreading),
|
1141 |
+
saved_state=self.raw_julia_state,
|
1142 |
)
|
1143 |
|
1144 |
self.variable_names = variable_names
|