Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
3dff82f
1
Parent(s):
e7b4ea9
Make __init__ not modify parameters again
Browse files- pysr/sr.py +120 -126
pysr/sr.py
CHANGED
@@ -759,8 +759,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
759 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
760 |
)
|
761 |
|
762 |
-
self._process_params()
|
763 |
-
|
764 |
def __repr__(self):
|
765 |
"""
|
766 |
Prints all current equations fitted by the model.
|
@@ -865,105 +863,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
865 |
f"{self.model_selection} is not a valid model selection strategy."
|
866 |
)
|
867 |
|
868 |
-
def _process_params(self):
|
869 |
-
"""
|
870 |
-
Perform validation on the parameters defined in init for the
|
871 |
-
dataset specified in :term`fit`, and update them if necessary.
|
872 |
-
For example, this will change :param`binary_operators`
|
873 |
-
into `["+", "-", "*", "/"]` if `binary_operators` is `None`.
|
874 |
-
|
875 |
-
Raises
|
876 |
-
------
|
877 |
-
ValueError
|
878 |
-
Raised when on of the following occurs: `tournament_selection_n`
|
879 |
-
parameter is larger than `population_size`; `maxsize` is
|
880 |
-
less than 7; invalid `extra_jax_mappings` or
|
881 |
-
`extra_torch_mappings`; invalid optimizer algorithms.
|
882 |
-
|
883 |
-
"""
|
884 |
-
# Handle None values for instance parameters:
|
885 |
-
if self.binary_operators is None:
|
886 |
-
self.binary_operators = "+ * - /".split(" ")
|
887 |
-
if self.unary_operators is None:
|
888 |
-
self.unary_operators = []
|
889 |
-
if self.extra_sympy_mappings is None:
|
890 |
-
self.extra_sympy_mappings = {}
|
891 |
-
if self.constraints is None:
|
892 |
-
self.constraints = {}
|
893 |
-
if self.multithreading is None:
|
894 |
-
# Default is multithreading=True, unless explicitly set,
|
895 |
-
# or procs is set to 0 (serial mode).
|
896 |
-
self.multithreading = self.procs != 0 and self.cluster_manager is None
|
897 |
-
if self.update_verbosity is None:
|
898 |
-
self.update_verbosity = self.verbosity
|
899 |
-
if self.maxdepth is None:
|
900 |
-
self.maxdepth = self.maxsize
|
901 |
-
|
902 |
-
# Handle type conversion for instance parameters:
|
903 |
-
if isinstance(self.binary_operators, str):
|
904 |
-
self.binary_operators = [self.binary_operators]
|
905 |
-
if isinstance(self.unary_operators, str):
|
906 |
-
self.unary_operators = [self.unary_operators]
|
907 |
-
|
908 |
-
# Warn if instance parameters are not sensible values:
|
909 |
-
if self.batch_size < 1:
|
910 |
-
warnings.warn(
|
911 |
-
"Given :param`batch_size` must be greater than or equal to one. "
|
912 |
-
":param`batch_size` has been increased to equal one."
|
913 |
-
)
|
914 |
-
self.batch_size = 1
|
915 |
-
|
916 |
-
# Ensure instance parameters are allowable values:
|
917 |
-
# ValueError - Incompatible values
|
918 |
-
if self.tournament_selection_n > self.population_size:
|
919 |
-
raise ValueError(
|
920 |
-
"tournament_selection_n parameter must be smaller than population_size."
|
921 |
-
)
|
922 |
-
|
923 |
-
if self.maxsize > 40:
|
924 |
-
warnings.warn(
|
925 |
-
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
926 |
-
)
|
927 |
-
elif self.maxsize < 7:
|
928 |
-
raise ValueError("PySR requires a maxsize of at least 7")
|
929 |
-
|
930 |
-
if self.extra_jax_mappings is not None:
|
931 |
-
for value in self.extra_jax_mappings.values():
|
932 |
-
if not isinstance(value, str):
|
933 |
-
raise ValueError(
|
934 |
-
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
935 |
-
)
|
936 |
-
else:
|
937 |
-
self.extra_jax_mappings = {}
|
938 |
-
|
939 |
-
if self.extra_torch_mappings is not None:
|
940 |
-
for value in self.extra_jax_mappings.values():
|
941 |
-
if not callable(value):
|
942 |
-
raise ValueError(
|
943 |
-
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
944 |
-
)
|
945 |
-
else:
|
946 |
-
self.extra_torch_mappings = {}
|
947 |
-
|
948 |
-
# NotImplementedError - Values that could be supported at a later time
|
949 |
-
if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
|
950 |
-
raise NotImplementedError(
|
951 |
-
f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
|
952 |
-
)
|
953 |
-
|
954 |
-
# Handle presentation of the progress bar:
|
955 |
-
buffer_available = "buffer" in sys.stdout.__dir__()
|
956 |
-
if self.progress is not None:
|
957 |
-
if self.progress and not buffer_available:
|
958 |
-
warnings.warn(
|
959 |
-
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
960 |
-
)
|
961 |
-
self.progress = False
|
962 |
-
else:
|
963 |
-
self.progress = buffer_available
|
964 |
-
|
965 |
-
return self
|
966 |
-
|
967 |
def _setup_equation_file(self):
|
968 |
"""
|
969 |
Sets the full pathname of the equation file, using :param`tempdir` and
|
@@ -1016,6 +915,39 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1016 |
|
1017 |
"""
|
1018 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1019 |
if isinstance(X, pd.DataFrame):
|
1020 |
if variable_names:
|
1021 |
variable_names = None
|
@@ -1165,23 +1097,82 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1165 |
global already_ran
|
1166 |
global Main
|
1167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1168 |
# Start julia backend processes
|
1169 |
if Main is None:
|
1170 |
-
if
|
1171 |
os.environ["JULIA_NUM_THREADS"] = str(self.procs)
|
1172 |
|
1173 |
Main = init_julia()
|
1174 |
|
1175 |
-
if
|
1176 |
-
Main.eval(f"import ClusterManagers: addprocs_{
|
1177 |
-
cluster_manager = Main.eval(f"addprocs_{
|
1178 |
-
else:
|
1179 |
-
cluster_manager = None
|
1180 |
|
1181 |
if not already_ran:
|
1182 |
julia_project, is_shared = _get_julia_project(self.julia_project)
|
1183 |
Main.eval("using Pkg")
|
1184 |
-
io = "devnull" if
|
1185 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1186 |
|
1187 |
Main.eval(
|
@@ -1211,39 +1202,35 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1211 |
|
1212 |
# TODO(mcranmer): These functions should be part of this class.
|
1213 |
binary_operators, unary_operators = _maybe_create_inline_operators(
|
1214 |
-
binary_operators=
|
1215 |
)
|
1216 |
constraints = _process_constraints(
|
1217 |
binary_operators=binary_operators,
|
1218 |
unary_operators=unary_operators,
|
1219 |
-
constraints=
|
1220 |
)
|
1221 |
|
1222 |
una_constraints = [constraints[op] for op in unary_operators]
|
1223 |
bin_constraints = [constraints[op] for op in binary_operators]
|
1224 |
|
1225 |
# Parse dict into Julia Dict for nested constraints::
|
1226 |
-
if
|
1227 |
nested_constraints_str = "Dict("
|
1228 |
-
for outer_k, outer_v in
|
1229 |
nested_constraints_str += f"({outer_k}) => Dict("
|
1230 |
for inner_k, inner_v in outer_v.items():
|
1231 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
1232 |
nested_constraints_str += "), "
|
1233 |
nested_constraints_str += ")"
|
1234 |
nested_constraints = Main.eval(nested_constraints_str)
|
1235 |
-
else:
|
1236 |
-
nested_constraints = None
|
1237 |
|
1238 |
# Parse dict into Julia Dict for complexities:
|
1239 |
-
if
|
1240 |
complexity_of_operators_str = "Dict("
|
1241 |
-
for k, v in
|
1242 |
complexity_of_operators_str += f"({k}) => {v}, "
|
1243 |
complexity_of_operators_str += ")"
|
1244 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
1245 |
-
else:
|
1246 |
-
complexity_of_operators = None
|
1247 |
|
1248 |
Main.custom_loss = Main.eval(self.loss)
|
1249 |
|
@@ -1274,14 +1261,14 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1274 |
hofFile=_escape_filename(self.equation_file_),
|
1275 |
npopulations=int(self.populations),
|
1276 |
batching=self.batching,
|
1277 |
-
batchSize=int(min([
|
1278 |
mutationWeights=mutationWeights,
|
1279 |
probPickFirst=self.tournament_selection_p,
|
1280 |
ns=self.tournament_selection_n,
|
1281 |
# These have the same name:
|
1282 |
parsimony=self.parsimony,
|
1283 |
alpha=self.alpha,
|
1284 |
-
maxdepth=
|
1285 |
fast_cycle=self.fast_cycle,
|
1286 |
migration=self.migration,
|
1287 |
hofMigration=self.hof_migration,
|
@@ -1302,7 +1289,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1302 |
perturbationFactor=self.perturbation_factor,
|
1303 |
annealing=self.annealing,
|
1304 |
stateReturn=True, # Required for state saving.
|
1305 |
-
progress=
|
1306 |
timeout_in_seconds=self.timeout_in_seconds,
|
1307 |
crossoverProbability=self.crossover_probability,
|
1308 |
skip_mutation_failures=self.skip_mutation_failures,
|
@@ -1313,6 +1300,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1313 |
# Convert data to desired precision
|
1314 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
1315 |
|
|
|
1316 |
Main.X = np.array(X, dtype=np_dtype).T
|
1317 |
if len(y.shape) == 1:
|
1318 |
Main.y = np.array(y, dtype=np_dtype)
|
@@ -1326,7 +1314,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1326 |
else:
|
1327 |
Main.weights = None
|
1328 |
|
1329 |
-
cprocs = 0 if
|
1330 |
|
1331 |
# Call to Julia backend.
|
1332 |
# See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
|
@@ -1338,7 +1326,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1338 |
varMap=self.feature_names_in_.tolist(),
|
1339 |
options=options,
|
1340 |
numprocs=int(cprocs),
|
1341 |
-
multithreading=bool(
|
1342 |
saved_state=self.raw_julia_state_,
|
1343 |
addprocs_function=cluster_manager,
|
1344 |
)
|
@@ -1714,7 +1702,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1714 |
if self.output_torch_format:
|
1715 |
torch_format = []
|
1716 |
local_sympy_mappings = {
|
1717 |
-
**self.extra_sympy_mappings,
|
1718 |
**sympy_mappings,
|
1719 |
}
|
1720 |
|
@@ -1741,7 +1729,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1741 |
eqn,
|
1742 |
sympy_symbols,
|
1743 |
selection=self.selection_mask_,
|
1744 |
-
extra_jax_mappings=
|
|
|
|
|
1745 |
)
|
1746 |
jax_format.append({"callable": func, "parameters": params})
|
1747 |
|
@@ -1753,7 +1743,11 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1753 |
eqn,
|
1754 |
sympy_symbols,
|
1755 |
selection=self.selection_mask_,
|
1756 |
-
extra_torch_mappings=
|
|
|
|
|
|
|
|
|
1757 |
)
|
1758 |
torch_format.append(module)
|
1759 |
|
|
|
759 |
f"{k} is not a valid keyword argument for PySRRegressor"
|
760 |
)
|
761 |
|
|
|
|
|
762 |
def __repr__(self):
|
763 |
"""
|
764 |
Prints all current equations fitted by the model.
|
|
|
863 |
f"{self.model_selection} is not a valid model selection strategy."
|
864 |
)
|
865 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
866 |
def _setup_equation_file(self):
|
867 |
"""
|
868 |
Sets the full pathname of the equation file, using :param`tempdir` and
|
|
|
915 |
|
916 |
"""
|
917 |
|
918 |
+
# Ensure instance parameters are allowable values:
|
919 |
+
if self.tournament_selection_n > self.population_size:
|
920 |
+
raise ValueError(
|
921 |
+
"tournament_selection_n parameter must be smaller than population_size."
|
922 |
+
)
|
923 |
+
|
924 |
+
if self.maxsize > 40:
|
925 |
+
warnings.warn(
|
926 |
+
"Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
927 |
+
)
|
928 |
+
elif self.maxsize < 7:
|
929 |
+
raise ValueError("PySR requires a maxsize of at least 7")
|
930 |
+
|
931 |
+
if self.extra_jax_mappings is not None:
|
932 |
+
for value in self.extra_jax_mappings.values():
|
933 |
+
if not isinstance(value, str):
|
934 |
+
raise ValueError(
|
935 |
+
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
936 |
+
)
|
937 |
+
|
938 |
+
if self.extra_torch_mappings is not None:
|
939 |
+
for value in self.extra_jax_mappings.values():
|
940 |
+
if not callable(value):
|
941 |
+
raise ValueError(
|
942 |
+
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
943 |
+
)
|
944 |
+
|
945 |
+
# NotImplementedError - Values that could be supported at a later time
|
946 |
+
if self.optimizer_algorithm not in self.VALID_OPTIMIZER_ALGORITHMS:
|
947 |
+
raise NotImplementedError(
|
948 |
+
f"PySR currently only supports the following optimizer algorithms: {self.VALID_OPTIMIZER_ALGORITHMS}"
|
949 |
+
)
|
950 |
+
|
951 |
if isinstance(X, pd.DataFrame):
|
952 |
if variable_names:
|
953 |
variable_names = None
|
|
|
1097 |
global already_ran
|
1098 |
global Main
|
1099 |
|
1100 |
+
# These are the parameters which may be modified from the ones
|
1101 |
+
# specified in init, so we define them here locally:
|
1102 |
+
binary_operators = self.binary_operators
|
1103 |
+
unary_operators = self.unary_operators
|
1104 |
+
constraints = self.constraints
|
1105 |
+
nested_constraints = self.nested_constraints
|
1106 |
+
complexity_of_operators = self.complexity_of_operators
|
1107 |
+
multithreading = self.multithreading
|
1108 |
+
update_verbosity = self.update_verbosity
|
1109 |
+
maxdepth = self.maxdepth
|
1110 |
+
batch_size = self.batch_size
|
1111 |
+
progress = self.progress
|
1112 |
+
cluster_manager = self.cluster_manager
|
1113 |
+
|
1114 |
+
# TODO: Clean this up into a readable format, such that
|
1115 |
+
# a function call automatically configures each default.
|
1116 |
+
|
1117 |
+
# Deal with default values, and type conversions:
|
1118 |
+
if binary_operators is None:
|
1119 |
+
binary_operators = "+ * - /".split(" ")
|
1120 |
+
elif isinstance(binary_operators, str):
|
1121 |
+
binary_operators = [binary_operators]
|
1122 |
+
|
1123 |
+
if unary_operators is None:
|
1124 |
+
unary_operators = []
|
1125 |
+
elif isinstance(unary_operators, str):
|
1126 |
+
unary_operators = [unary_operators]
|
1127 |
+
|
1128 |
+
if constraints is None:
|
1129 |
+
constraints = {}
|
1130 |
+
|
1131 |
+
if multithreading is None:
|
1132 |
+
# Default is multithreading=True, unless explicitly set,
|
1133 |
+
# or procs is set to 0 (serial mode).
|
1134 |
+
multithreading = self.procs != 0 and cluster_manager is None
|
1135 |
+
|
1136 |
+
if update_verbosity is None:
|
1137 |
+
update_verbosity = self.verbosity
|
1138 |
+
|
1139 |
+
if maxdepth is None:
|
1140 |
+
maxdepth = self.maxsize
|
1141 |
+
|
1142 |
+
# Warn if instance parameters are not sensible values:
|
1143 |
+
if batch_size < 1:
|
1144 |
+
warnings.warn(
|
1145 |
+
"Given :param`batch_size` must be greater than or equal to one. "
|
1146 |
+
":param`batch_size` has been increased to equal one."
|
1147 |
+
)
|
1148 |
+
batch_size = 1
|
1149 |
+
|
1150 |
+
# Handle presentation of the progress bar:
|
1151 |
+
buffer_available = "buffer" in sys.stdout.__dir__()
|
1152 |
+
if progress is not None:
|
1153 |
+
if progress and not buffer_available:
|
1154 |
+
warnings.warn(
|
1155 |
+
"Note: it looks like you are running in Jupyter. The progress bar will be turned off."
|
1156 |
+
)
|
1157 |
+
progress = False
|
1158 |
+
else:
|
1159 |
+
progress = buffer_available
|
1160 |
+
|
1161 |
# Start julia backend processes
|
1162 |
if Main is None:
|
1163 |
+
if multithreading:
|
1164 |
os.environ["JULIA_NUM_THREADS"] = str(self.procs)
|
1165 |
|
1166 |
Main = init_julia()
|
1167 |
|
1168 |
+
if cluster_manager is not None:
|
1169 |
+
Main.eval(f"import ClusterManagers: addprocs_{cluster_manager}")
|
1170 |
+
cluster_manager = Main.eval(f"addprocs_{cluster_manager}")
|
|
|
|
|
1171 |
|
1172 |
if not already_ran:
|
1173 |
julia_project, is_shared = _get_julia_project(self.julia_project)
|
1174 |
Main.eval("using Pkg")
|
1175 |
+
io = "devnull" if update_verbosity == 0 else "stderr"
|
1176 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1177 |
|
1178 |
Main.eval(
|
|
|
1202 |
|
1203 |
# TODO(mcranmer): These functions should be part of this class.
|
1204 |
binary_operators, unary_operators = _maybe_create_inline_operators(
|
1205 |
+
binary_operators=binary_operators, unary_operators=unary_operators
|
1206 |
)
|
1207 |
constraints = _process_constraints(
|
1208 |
binary_operators=binary_operators,
|
1209 |
unary_operators=unary_operators,
|
1210 |
+
constraints=constraints,
|
1211 |
)
|
1212 |
|
1213 |
una_constraints = [constraints[op] for op in unary_operators]
|
1214 |
bin_constraints = [constraints[op] for op in binary_operators]
|
1215 |
|
1216 |
# Parse dict into Julia Dict for nested constraints::
|
1217 |
+
if nested_constraints is not None:
|
1218 |
nested_constraints_str = "Dict("
|
1219 |
+
for outer_k, outer_v in nested_constraints.items():
|
1220 |
nested_constraints_str += f"({outer_k}) => Dict("
|
1221 |
for inner_k, inner_v in outer_v.items():
|
1222 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
1223 |
nested_constraints_str += "), "
|
1224 |
nested_constraints_str += ")"
|
1225 |
nested_constraints = Main.eval(nested_constraints_str)
|
|
|
|
|
1226 |
|
1227 |
# Parse dict into Julia Dict for complexities:
|
1228 |
+
if complexity_of_operators is not None:
|
1229 |
complexity_of_operators_str = "Dict("
|
1230 |
+
for k, v in complexity_of_operators.items():
|
1231 |
complexity_of_operators_str += f"({k}) => {v}, "
|
1232 |
complexity_of_operators_str += ")"
|
1233 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
|
|
|
|
1234 |
|
1235 |
Main.custom_loss = Main.eval(self.loss)
|
1236 |
|
|
|
1261 |
hofFile=_escape_filename(self.equation_file_),
|
1262 |
npopulations=int(self.populations),
|
1263 |
batching=self.batching,
|
1264 |
+
batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
|
1265 |
mutationWeights=mutationWeights,
|
1266 |
probPickFirst=self.tournament_selection_p,
|
1267 |
ns=self.tournament_selection_n,
|
1268 |
# These have the same name:
|
1269 |
parsimony=self.parsimony,
|
1270 |
alpha=self.alpha,
|
1271 |
+
maxdepth=maxdepth,
|
1272 |
fast_cycle=self.fast_cycle,
|
1273 |
migration=self.migration,
|
1274 |
hofMigration=self.hof_migration,
|
|
|
1289 |
perturbationFactor=self.perturbation_factor,
|
1290 |
annealing=self.annealing,
|
1291 |
stateReturn=True, # Required for state saving.
|
1292 |
+
progress=progress,
|
1293 |
timeout_in_seconds=self.timeout_in_seconds,
|
1294 |
crossoverProbability=self.crossover_probability,
|
1295 |
skip_mutation_failures=self.skip_mutation_failures,
|
|
|
1300 |
# Convert data to desired precision
|
1301 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
1302 |
|
1303 |
+
# This converts the data into a Julia array:
|
1304 |
Main.X = np.array(X, dtype=np_dtype).T
|
1305 |
if len(y.shape) == 1:
|
1306 |
Main.y = np.array(y, dtype=np_dtype)
|
|
|
1314 |
else:
|
1315 |
Main.weights = None
|
1316 |
|
1317 |
+
cprocs = 0 if multithreading else self.procs
|
1318 |
|
1319 |
# Call to Julia backend.
|
1320 |
# See https://github.com/search?q=%22function+EquationSearch%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3ASymbolicRegression.jl+language%3AJulia&type=Code
|
|
|
1326 |
varMap=self.feature_names_in_.tolist(),
|
1327 |
options=options,
|
1328 |
numprocs=int(cprocs),
|
1329 |
+
multithreading=bool(multithreading),
|
1330 |
saved_state=self.raw_julia_state_,
|
1331 |
addprocs_function=cluster_manager,
|
1332 |
)
|
|
|
1702 |
if self.output_torch_format:
|
1703 |
torch_format = []
|
1704 |
local_sympy_mappings = {
|
1705 |
+
**(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
|
1706 |
**sympy_mappings,
|
1707 |
}
|
1708 |
|
|
|
1729 |
eqn,
|
1730 |
sympy_symbols,
|
1731 |
selection=self.selection_mask_,
|
1732 |
+
extra_jax_mappings=(
|
1733 |
+
self.extra_jax_mappings if self.extra_jax_mappings else {}
|
1734 |
+
),
|
1735 |
)
|
1736 |
jax_format.append({"callable": func, "parameters": params})
|
1737 |
|
|
|
1743 |
eqn,
|
1744 |
sympy_symbols,
|
1745 |
selection=self.selection_mask_,
|
1746 |
+
extra_torch_mappings=(
|
1747 |
+
self.extra_torch_mappings
|
1748 |
+
if self.extra_torch_mappings
|
1749 |
+
else {}
|
1750 |
+
),
|
1751 |
)
|
1752 |
torch_format.append(module)
|
1753 |
|