MilesCranmer commited on
Commit
9335949
2 Parent(s): fc75036 73aff8b

Merge pull request #134 from MilesCranmer/max_evals

Browse files

New exit strategies: max_evals and early_stop_condition

Files changed (2) hide show
  1. pysr/sr.py +10 -5
  2. pysr/version.py +2 -2
pysr/sr.py CHANGED
@@ -418,8 +418,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
418
  precision=32,
419
  multithreading=None,
420
  cluster_manager=None,
421
- use_symbolic_utils=False,
422
  skip_mutation_failures=True,
 
 
423
  # To support deprecated kwargs:
424
  **kwargs,
425
  ):
@@ -558,10 +559,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
558
  :type tournament_selection_p: float
559
  :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
560
  :type precision: int
561
- :param use_symbolic_utils: Whether to use SymbolicUtils during simplification.
562
- :type use_symbolic_utils: bool
563
  :param skip_mutation_failures: Whether to skip mutation and crossover failures, rather than simply re-sampling the current member.
564
  :type skip_mutation_failures: bool
 
 
 
 
565
  :param kwargs: Supports deprecated keyword arguments. Other arguments will result
566
  in an error
567
  :type kwargs: dict
@@ -747,8 +750,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
747
  precision=precision,
748
  multithreading=multithreading,
749
  cluster_manager=cluster_manager,
750
- use_symbolic_utils=use_symbolic_utils,
751
  skip_mutation_failures=skip_mutation_failures,
 
 
752
  ),
753
  }
754
 
@@ -1310,11 +1314,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
1310
  perturbationFactor=self.params["perturbation_factor"],
1311
  annealing=self.params["annealing"],
1312
  stateReturn=True, # Required for state saving.
1313
- use_symbolic_utils=self.params["use_symbolic_utils"],
1314
  progress=self.params["progress"],
1315
  timeout_in_seconds=self.params["timeout_in_seconds"],
1316
  crossoverProbability=self.params["crossover_probability"],
1317
  skip_mutation_failures=self.params["skip_mutation_failures"],
 
 
1318
  )
1319
 
1320
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
 
418
  precision=32,
419
  multithreading=None,
420
  cluster_manager=None,
 
421
  skip_mutation_failures=True,
422
+ max_evals=None,
423
+ early_stop_condition=None,
424
  # To support deprecated kwargs:
425
  **kwargs,
426
  ):
 
559
  :type tournament_selection_p: float
560
  :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
561
  :type precision: int
 
 
562
  :param skip_mutation_failures: Whether to skip mutation and crossover failures, rather than simply re-sampling the current member.
563
  :type skip_mutation_failures: bool
564
+ :param max_evals: Limits the total number of evaluations of expressions to this number.
565
+ :type max_evals: int
566
+ :param early_stop_condition: Stop the search early if this loss is reached.
567
+ :type early_stop_condition: float
568
  :param kwargs: Supports deprecated keyword arguments. Other arguments will result
569
  in an error
570
  :type kwargs: dict
 
750
  precision=precision,
751
  multithreading=multithreading,
752
  cluster_manager=cluster_manager,
 
753
  skip_mutation_failures=skip_mutation_failures,
754
+ max_evals=max_evals,
755
+ early_stop_condition=early_stop_condition,
756
  ),
757
  }
758
 
 
1314
  perturbationFactor=self.params["perturbation_factor"],
1315
  annealing=self.params["annealing"],
1316
  stateReturn=True, # Required for state saving.
 
1317
  progress=self.params["progress"],
1318
  timeout_in_seconds=self.params["timeout_in_seconds"],
1319
  crossoverProbability=self.params["crossover_probability"],
1320
  skip_mutation_failures=self.params["skip_mutation_failures"],
1321
+ max_evals=self.params["max_evals"],
1322
+ earlyStopCondition=self.params["early_stop_condition"],
1323
  )
1324
 
1325
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
pysr/version.py CHANGED
@@ -1,2 +1,2 @@
1
- __version__ = "0.8.1"
2
- __symbolic_regression_jl_version__ = "0.8.7"
 
1
+ __version__ = "0.8.2"
2
+ __symbolic_regression_jl_version__ = "0.9.1"