MilesCranmer commited on
Commit
40f498c
1 Parent(s): f06ee71

Remove misplaced parameter changes

Browse files
Files changed (1) hide show
  1. pysr/sr.py +15 -11
pysr/sr.py CHANGED
@@ -974,7 +974,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
974
  else:
975
  self.equation_file_ = self.equation_file
976
 
977
-
978
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
979
  """
980
  Validates the parameters passed to the :term`fit` method.
@@ -1180,17 +1179,18 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1180
 
1181
  if self.cluster_manager is not None:
1182
  Main.eval(f"import ClusterManagers: addprocs_{self.cluster_manager}")
1183
- self.cluster_manager = Main.eval(f"addprocs_{self.cluster_manager}")
1184
-
1185
- self.julia_project, is_shared = _get_julia_project(self.julia_project)
1186
 
1187
  if not already_ran:
 
1188
  Main.eval("using Pkg")
1189
  io = "devnull" if self.update_verbosity == 0 else "stderr"
1190
  io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
1191
 
1192
  Main.eval(
1193
- f'Pkg.activate("{_escape_filename(self.julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
1194
  )
1195
  from julia.api import JuliaError
1196
 
@@ -1205,7 +1205,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1205
  else:
1206
  Main.eval(f"Pkg.instantiate({io_arg})")
1207
  except (JuliaError, RuntimeError) as e:
1208
- raise ImportError(import_error_string(self.julia_project)) from e
1209
  Main.eval("using SymbolicRegression")
1210
 
1211
  Main.plus = Main.eval("(+)")
@@ -1235,7 +1235,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1235
  nested_constraints_str += f"({inner_k}) => {inner_v}, "
1236
  nested_constraints_str += "), "
1237
  nested_constraints_str += ")"
1238
- self.nested_constraints = Main.eval(nested_constraints_str)
 
 
1239
 
1240
  # Parse dict into Julia Dict for complexities:
1241
  if self.complexity_of_operators is not None:
@@ -1243,7 +1245,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1243
  for k, v in self.complexity_of_operators.items():
1244
  complexity_of_operators_str += f"({k}) => {v}, "
1245
  complexity_of_operators_str += ")"
1246
- self.complexity_of_operators = Main.eval(complexity_of_operators_str)
 
 
1247
 
1248
  Main.custom_loss = Main.eval(self.loss)
1249
 
@@ -1269,10 +1273,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1269
  ),
1270
  bin_constraints=bin_constraints,
1271
  una_constraints=una_constraints,
1272
- complexity_of_operators=self.complexity_of_operators,
1273
  complexity_of_constants=self.complexity_of_constants,
1274
  complexity_of_variables=self.complexity_of_variables,
1275
- nested_constraints=self.nested_constraints,
1276
  loss=Main.custom_loss,
1277
  maxsize=int(self.maxsize),
1278
  hofFile=_escape_filename(self.equation_file_),
@@ -1344,7 +1348,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1344
  numprocs=int(cprocs),
1345
  multithreading=bool(self.multithreading),
1346
  saved_state=self.raw_julia_state_,
1347
- addprocs_function=self.cluster_manager,
1348
  )
1349
 
1350
  # Set attributes
 
974
  else:
975
  self.equation_file_ = self.equation_file
976
 
 
977
  def _validate_fit_params(self, X, y, Xresampled, variable_names):
978
  """
979
  Validates the parameters passed to the :term`fit` method.
 
1179
 
1180
  if self.cluster_manager is not None:
1181
  Main.eval(f"import ClusterManagers: addprocs_{self.cluster_manager}")
1182
+ cluster_manager = Main.eval(f"addprocs_{self.cluster_manager}")
1183
+ else:
1184
+ cluster_manager = None
1185
 
1186
  if not already_ran:
1187
+ julia_project, is_shared = _get_julia_project(self.julia_project)
1188
  Main.eval("using Pkg")
1189
  io = "devnull" if self.update_verbosity == 0 else "stderr"
1190
  io_arg = f"io={io}" if is_julia_version_greater_eq(Main, "1.6") else ""
1191
 
1192
  Main.eval(
1193
+ f'Pkg.activate("{_escape_filename(julia_project)}", shared = Bool({int(is_shared)}), {io_arg})'
1194
  )
1195
  from julia.api import JuliaError
1196
 
 
1205
  else:
1206
  Main.eval(f"Pkg.instantiate({io_arg})")
1207
  except (JuliaError, RuntimeError) as e:
1208
+ raise ImportError(import_error_string(julia_project)) from e
1209
  Main.eval("using SymbolicRegression")
1210
 
1211
  Main.plus = Main.eval("(+)")
 
1235
  nested_constraints_str += f"({inner_k}) => {inner_v}, "
1236
  nested_constraints_str += "), "
1237
  nested_constraints_str += ")"
1238
+ nested_constraints = Main.eval(nested_constraints_str)
1239
+ else:
1240
+ nested_constraints = None
1241
 
1242
  # Parse dict into Julia Dict for complexities:
1243
  if self.complexity_of_operators is not None:
 
1245
  for k, v in self.complexity_of_operators.items():
1246
  complexity_of_operators_str += f"({k}) => {v}, "
1247
  complexity_of_operators_str += ")"
1248
+ complexity_of_operators = Main.eval(complexity_of_operators_str)
1249
+ else:
1250
+ complexity_of_operators = None
1251
 
1252
  Main.custom_loss = Main.eval(self.loss)
1253
 
 
1273
  ),
1274
  bin_constraints=bin_constraints,
1275
  una_constraints=una_constraints,
1276
+ complexity_of_operators=complexity_of_operators,
1277
  complexity_of_constants=self.complexity_of_constants,
1278
  complexity_of_variables=self.complexity_of_variables,
1279
+ nested_constraints=nested_constraints,
1280
  loss=Main.custom_loss,
1281
  maxsize=int(self.maxsize),
1282
  hofFile=_escape_filename(self.equation_file_),
 
1348
  numprocs=int(cprocs),
1349
  multithreading=bool(self.multithreading),
1350
  saved_state=self.raw_julia_state_,
1351
+ addprocs_function=cluster_manager,
1352
  )
1353
 
1354
  # Set attributes