MilesCranmer commited on
Commit
44dcbea
·
1 Parent(s): 4c9fe98

Allow functional versions of early stop condition

Browse files
Files changed (1) hide show
  1. pysr/sr.py +8 -3
pysr/sr.py CHANGED
@@ -312,8 +312,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
312
  annealing : bool, default=True
313
  Whether to use annealing. You should (and it is default).
314
 
315
- early_stop_condition : float, default=None
316
- Stop the search early if this loss is reached.
 
 
 
317
 
318
  ncyclesperiteration : int, default=550
319
  Number of total mutations to run, per 10 samples of the
@@ -971,6 +974,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
971
 
972
  # 'Mutable' parameter validation
973
  buffer_available = "buffer" in sys.stdout.__dir__()
 
974
  modifiable_params = {
975
  "binary_operators": "+ * - /".split(" "),
976
  "unary_operators": [],
@@ -1308,6 +1312,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1308
  complexity_of_operators = Main.eval(complexity_of_operators_str)
1309
 
1310
  custom_loss = Main.eval(self.loss)
 
1311
 
1312
  mutationWeights = [
1313
  float(self.weight_mutate_constant),
@@ -1369,7 +1374,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1369
  crossoverProbability=self.crossover_probability,
1370
  skip_mutation_failures=self.skip_mutation_failures,
1371
  max_evals=self.max_evals,
1372
- earlyStopCondition=self.early_stop_condition,
1373
  seed=seed,
1374
  deterministic=self.deterministic,
1375
  )
 
312
  annealing : bool, default=True
313
  Whether to use annealing. You should (and it is default).
314
 
315
+ early_stop_condition : { float | str }, default=None
316
+ Stop the search early if this loss is reached. You may also
317
+ pass a string containing a Julia function which
318
+ takes a loss and complexity as input, for example:
319
+ `"f(loss, complexity) = (loss < 0.1) && (complexity < 10)"`.
320
 
321
  ncyclesperiteration : int, default=550
322
  Number of total mutations to run, per 10 samples of the
 
974
 
975
  # 'Mutable' parameter validation
976
  buffer_available = "buffer" in sys.stdout.__dir__()
977
+ # Params and their default values, if None is given:
978
  modifiable_params = {
979
  "binary_operators": "+ * - /".split(" "),
980
  "unary_operators": [],
 
1312
  complexity_of_operators = Main.eval(complexity_of_operators_str)
1313
 
1314
  custom_loss = Main.eval(self.loss)
1315
+ early_stop_condition = Main.eval(self.early_stop_condition)
1316
 
1317
  mutationWeights = [
1318
  float(self.weight_mutate_constant),
 
1374
  crossoverProbability=self.crossover_probability,
1375
  skip_mutation_failures=self.skip_mutation_failures,
1376
  max_evals=self.max_evals,
1377
+ earlyStopCondition=early_stop_condition,
1378
  seed=seed,
1379
  deterministic=self.deterministic,
1380
  )