MilesCranmer commited on
Commit
73aff8b
·
1 Parent(s): ab66141

Add early_stop_condition to stop earlier

Browse files
Files changed (1) hide show
  1. pysr/sr.py +6 -1
pysr/sr.py CHANGED
@@ -420,6 +420,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
420
  cluster_manager=None,
421
  skip_mutation_failures=True,
422
  max_evals=None,
 
423
  # To support deprecated kwargs:
424
  **kwargs,
425
  ):
@@ -562,6 +563,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
562
  :type skip_mutation_failures: bool
563
  :param max_evals: Limits the total number of evaluations of expressions to this number.
564
  :type max_evals: int
 
 
565
  :param kwargs: Supports deprecated keyword arguments. Other arguments will result
566
  in an error
567
  :type kwargs: dict
@@ -749,6 +752,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
749
  cluster_manager=cluster_manager,
750
  skip_mutation_failures=skip_mutation_failures,
751
  max_evals=max_evals,
 
752
  ),
753
  }
754
 
@@ -1313,8 +1317,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1313
  progress=self.params["progress"],
1314
  timeout_in_seconds=self.params["timeout_in_seconds"],
1315
  crossoverProbability=self.params["crossover_probability"],
1316
- max_evals=self.params["max_evals"],
1317
  skip_mutation_failures=self.params["skip_mutation_failures"],
 
 
1318
  )
1319
 
1320
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
 
420
  cluster_manager=None,
421
  skip_mutation_failures=True,
422
  max_evals=None,
423
+ early_stop_condition=None,
424
  # To support deprecated kwargs:
425
  **kwargs,
426
  ):
 
563
  :type skip_mutation_failures: bool
564
  :param max_evals: Limits the total number of evaluations of expressions to this number.
565
  :type max_evals: int
566
+ :param early_stop_condition: Stop the search early if this loss is reached.
567
+ :type early_stop_condition: float
568
  :param kwargs: Supports deprecated keyword arguments. Other arguments will result
569
  in an error
570
  :type kwargs: dict
 
752
  cluster_manager=cluster_manager,
753
  skip_mutation_failures=skip_mutation_failures,
754
  max_evals=max_evals,
755
+ early_stop_condition=early_stop_condition,
756
  ),
757
  }
758
 
 
1317
  progress=self.params["progress"],
1318
  timeout_in_seconds=self.params["timeout_in_seconds"],
1319
  crossoverProbability=self.params["crossover_probability"],
 
1320
  skip_mutation_failures=self.params["skip_mutation_failures"],
1321
+ max_evals=self.params["max_evals"],
1322
+ earlyStopCondition=self.params["early_stop_condition"],
1323
  )
1324
 
1325
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[