MilesCranmer commited on
Commit
7a5a9a0
1 Parent(s): f07f6e6

Clean up mutation_weights setting

Browse files
Files changed (1) hide show
  1. pysr/sr.py +14 -11
pysr/sr.py CHANGED
@@ -1314,16 +1314,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1314
  custom_loss = Main.eval(self.loss)
1315
  early_stop_condition = Main.eval(str(self.early_stop_condition))
1316
 
1317
- mutationWeights = [
1318
- float(self.weight_mutate_constant),
1319
- float(self.weight_mutate_operator),
1320
- float(self.weight_add_node),
1321
- float(self.weight_insert_node),
1322
- float(self.weight_delete_node),
1323
- float(self.weight_simplify),
1324
- float(self.weight_randomize),
1325
- float(self.weight_do_nothing),
1326
- ]
 
 
 
1327
 
1328
  # Call to Julia backend.
1329
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
@@ -1342,7 +1345,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1342
  npopulations=int(self.populations),
1343
  batching=self.batching,
1344
  batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
1345
- mutationWeights=mutationWeights,
1346
  probPickFirst=self.tournament_selection_p,
1347
  ns=self.tournament_selection_n,
1348
  # These have the same name:
 
1314
  custom_loss = Main.eval(self.loss)
1315
  early_stop_condition = Main.eval(str(self.early_stop_condition))
1316
 
1317
+ mutation_weights = np.array(
1318
+ [
1319
+ self.weight_mutate_constant,
1320
+ self.weight_mutate_operator,
1321
+ self.weight_add_node,
1322
+ self.weight_insert_node,
1323
+ self.weight_delete_node,
1324
+ self.weight_simplify,
1325
+ self.weight_randomize,
1326
+ self.weight_do_nothing,
1327
+ ],
1328
+ dtype=float,
1329
+ )
1330
 
1331
  # Call to Julia backend.
1332
  # See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
 
1345
  npopulations=int(self.populations),
1346
  batching=self.batching,
1347
  batchSize=int(min([batch_size, len(X)]) if self.batching else len(X)),
1348
+ mutationWeights=mutation_weights,
1349
  probPickFirst=self.tournament_selection_p,
1350
  ns=self.tournament_selection_n,
1351
  # These have the same name: