Spaces:
Running
Running
MilesCranmer
commited on
Merge pull request #134 from MilesCranmer/max_evals
Browse filesNew exit strategies: max_evals and early_stop_condition
- pysr/sr.py +10 -5
- pysr/version.py +2 -2
pysr/sr.py
CHANGED
@@ -418,8 +418,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
418 |
precision=32,
|
419 |
multithreading=None,
|
420 |
cluster_manager=None,
|
421 |
-
use_symbolic_utils=False,
|
422 |
skip_mutation_failures=True,
|
|
|
|
|
423 |
# To support deprecated kwargs:
|
424 |
**kwargs,
|
425 |
):
|
@@ -558,10 +559,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
558 |
:type tournament_selection_p: float
|
559 |
:param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
|
560 |
:type precision: int
|
561 |
-
:param use_symbolic_utils: Whether to use SymbolicUtils during simplification.
|
562 |
-
:type use_symbolic_utils: bool
|
563 |
:param skip_mutation_failures: Whether to skip mutation and crossover failures, rather than simply re-sampling the current member.
|
564 |
:type skip_mutation_failures: bool
|
|
|
|
|
|
|
|
|
565 |
:param kwargs: Supports deprecated keyword arguments. Other arguments will result
|
566 |
in an error
|
567 |
:type kwargs: dict
|
@@ -747,8 +750,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
747 |
precision=precision,
|
748 |
multithreading=multithreading,
|
749 |
cluster_manager=cluster_manager,
|
750 |
-
use_symbolic_utils=use_symbolic_utils,
|
751 |
skip_mutation_failures=skip_mutation_failures,
|
|
|
|
|
752 |
),
|
753 |
}
|
754 |
|
@@ -1310,11 +1314,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
1310 |
perturbationFactor=self.params["perturbation_factor"],
|
1311 |
annealing=self.params["annealing"],
|
1312 |
stateReturn=True, # Required for state saving.
|
1313 |
-
use_symbolic_utils=self.params["use_symbolic_utils"],
|
1314 |
progress=self.params["progress"],
|
1315 |
timeout_in_seconds=self.params["timeout_in_seconds"],
|
1316 |
crossoverProbability=self.params["crossover_probability"],
|
1317 |
skip_mutation_failures=self.params["skip_mutation_failures"],
|
|
|
|
|
1318 |
)
|
1319 |
|
1320 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
|
|
418 |
precision=32,
|
419 |
multithreading=None,
|
420 |
cluster_manager=None,
|
|
|
421 |
skip_mutation_failures=True,
|
422 |
+
max_evals=None,
|
423 |
+
early_stop_condition=None,
|
424 |
# To support deprecated kwargs:
|
425 |
**kwargs,
|
426 |
):
|
|
|
559 |
:type tournament_selection_p: float
|
560 |
:param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
|
561 |
:type precision: int
|
|
|
|
|
562 |
:param skip_mutation_failures: Whether to skip mutation and crossover failures, rather than simply re-sampling the current member.
|
563 |
:type skip_mutation_failures: bool
|
564 |
+
:param max_evals: Limits the total number of evaluations of expressions to this number.
|
565 |
+
:type max_evals: int
|
566 |
+
:param early_stop_condition: Stop the search early if this loss is reached.
|
567 |
+
:type early_stop_condition: float
|
568 |
:param kwargs: Supports deprecated keyword arguments. Other arguments will result
|
569 |
in an error
|
570 |
:type kwargs: dict
|
|
|
750 |
precision=precision,
|
751 |
multithreading=multithreading,
|
752 |
cluster_manager=cluster_manager,
|
|
|
753 |
skip_mutation_failures=skip_mutation_failures,
|
754 |
+
max_evals=max_evals,
|
755 |
+
early_stop_condition=early_stop_condition,
|
756 |
),
|
757 |
}
|
758 |
|
|
|
1314 |
perturbationFactor=self.params["perturbation_factor"],
|
1315 |
annealing=self.params["annealing"],
|
1316 |
stateReturn=True, # Required for state saving.
|
|
|
1317 |
progress=self.params["progress"],
|
1318 |
timeout_in_seconds=self.params["timeout_in_seconds"],
|
1319 |
crossoverProbability=self.params["crossover_probability"],
|
1320 |
skip_mutation_failures=self.params["skip_mutation_failures"],
|
1321 |
+
max_evals=self.params["max_evals"],
|
1322 |
+
earlyStopCondition=self.params["early_stop_condition"],
|
1323 |
)
|
1324 |
|
1325 |
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
|
pysr/version.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
__version__ = "0.8.
|
2 |
-
__symbolic_regression_jl_version__ = "0.
|
|
|
1 |
+
__version__ = "0.8.2"
|
2 |
+
__symbolic_regression_jl_version__ = "0.9.1"
|