Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
40f498c
1
Parent(s):
f06ee71
Remove misplaced parameter changes
Browse files- pysr/sr.py +15 -11
pysr/sr.py
CHANGED
@@ -974,7 +974,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
974 |
else:
|
975 |
self.equation_file_ = self.equation_file
|
976 |
|
977 |
-
|
978 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
979 |
"""
|
980 |
Validates the parameters passed to the :term`fit` method.
|
@@ -1180,17 +1179,18 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1180 |
|
1181 |
if self.cluster_manager is not None:
|
1182 |
Main.eval(f"import ClusterManagers: addprocs_{self.cluster_manager}")
|
1183 |
-
|
1184 |
-
|
1185 |
-
|
1186 |
|
1187 |
if not already_ran:
|
|
|
1188 |
Main.eval("using Pkg")
|
1189 |
io = "devnull" if self.update_verbosity == 0 else "stderr"
|
1190 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1191 |
|
1192 |
Main.eval(
|
1193 |
-
f'Pkg.activate("{_escape_filename(
|
1194 |
)
|
1195 |
from julia.api import JuliaError
|
1196 |
|
@@ -1205,7 +1205,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1205 |
else:
|
1206 |
Main.eval(f"Pkg.instantiate({io_arg})")
|
1207 |
except (JuliaError, RuntimeError) as e:
|
1208 |
-
raise ImportError(import_error_string(
|
1209 |
Main.eval("using SymbolicRegression")
|
1210 |
|
1211 |
Main.plus = Main.eval("(+)")
|
@@ -1235,7 +1235,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1235 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
1236 |
nested_constraints_str += "), "
|
1237 |
nested_constraints_str += ")"
|
1238 |
-
|
|
|
|
|
1239 |
|
1240 |
# Parse dict into Julia Dict for complexities:
|
1241 |
if self.complexity_of_operators is not None:
|
@@ -1243,7 +1245,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1243 |
for k, v in self.complexity_of_operators.items():
|
1244 |
complexity_of_operators_str += f"({k}) => {v}, "
|
1245 |
complexity_of_operators_str += ")"
|
1246 |
-
|
|
|
|
|
1247 |
|
1248 |
Main.custom_loss = Main.eval(self.loss)
|
1249 |
|
@@ -1269,10 +1273,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1269 |
),
|
1270 |
bin_constraints=bin_constraints,
|
1271 |
una_constraints=una_constraints,
|
1272 |
-
complexity_of_operators=
|
1273 |
complexity_of_constants=self.complexity_of_constants,
|
1274 |
complexity_of_variables=self.complexity_of_variables,
|
1275 |
-
nested_constraints=
|
1276 |
loss=Main.custom_loss,
|
1277 |
maxsize=int(self.maxsize),
|
1278 |
hofFile=_escape_filename(self.equation_file_),
|
@@ -1344,7 +1348,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1344 |
numprocs=int(cprocs),
|
1345 |
multithreading=bool(self.multithreading),
|
1346 |
saved_state=self.raw_julia_state_,
|
1347 |
-
addprocs_function=
|
1348 |
)
|
1349 |
|
1350 |
# Set attributes
|
|
|
974 |
else:
|
975 |
self.equation_file_ = self.equation_file
|
976 |
|
|
|
977 |
def _validate_fit_params(self, X, y, Xresampled, variable_names):
|
978 |
"""
|
979 |
Validates the parameters passed to the :term`fit` method.
|
|
|
1179 |
|
1180 |
if self.cluster_manager is not None:
|
1181 |
Main.eval(f"import ClusterManagers: addprocs_{self.cluster_manager}")
|
1182 |
+
cluster_manager = Main.eval(f"addprocs_{self.cluster_manager}")
|
1183 |
+
else:
|
1184 |
+
cluster_manager = None
|
1185 |
|
1186 |
if not already_ran:
|
1187 |
+
julia_project, is_shared = _get_julia_project(self.julia_project)
|
1188 |
Main.eval("using Pkg")
|
1189 |
io = "devnull" if self.update_verbosity == 0 else "stderr"
|
1190 |
io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
|
1191 |
|
1192 |
Main.eval(
|
1193 |
+
f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
|
1194 |
)
|
1195 |
from julia.api import JuliaError
|
1196 |
|
|
|
1205 |
else:
|
1206 |
Main.eval(f"Pkg.instantiate({io_arg})")
|
1207 |
except (JuliaError, RuntimeError) as e:
|
1208 |
+
raise ImportError(import_error_string(julia_project)) from e
|
1209 |
Main.eval("using SymbolicRegression")
|
1210 |
|
1211 |
Main.plus = Main.eval("(+)")
|
|
|
1235 |
nested_constraints_str += f"({inner_k}) => {inner_v}, "
|
1236 |
nested_constraints_str += "), "
|
1237 |
nested_constraints_str += ")"
|
1238 |
+
nested_constraints = Main.eval(nested_constraints_str)
|
1239 |
+
else:
|
1240 |
+
nested_constraints = None
|
1241 |
|
1242 |
# Parse dict into Julia Dict for complexities:
|
1243 |
if self.complexity_of_operators is not None:
|
|
|
1245 |
for k, v in self.complexity_of_operators.items():
|
1246 |
complexity_of_operators_str += f"({k}) => {v}, "
|
1247 |
complexity_of_operators_str += ")"
|
1248 |
+
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
1249 |
+
else:
|
1250 |
+
complexity_of_operators = None
|
1251 |
|
1252 |
Main.custom_loss = Main.eval(self.loss)
|
1253 |
|
|
|
1273 |
),
|
1274 |
bin_constraints=bin_constraints,
|
1275 |
una_constraints=una_constraints,
|
1276 |
+
complexity_of_operators=complexity_of_operators,
|
1277 |
complexity_of_constants=self.complexity_of_constants,
|
1278 |
complexity_of_variables=self.complexity_of_variables,
|
1279 |
+
nested_constraints=nested_constraints,
|
1280 |
loss=Main.custom_loss,
|
1281 |
maxsize=int(self.maxsize),
|
1282 |
hofFile=_escape_filename(self.equation_file_),
|
|
|
1348 |
numprocs=int(cprocs),
|
1349 |
multithreading=bool(self.multithreading),
|
1350 |
saved_state=self.raw_julia_state_,
|
1351 |
+
addprocs_function=cluster_manager,
|
1352 |
)
|
1353 |
|
1354 |
# Set attributes
|