Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
44dcbea
1
Parent(s):
4c9fe98
Allow functional versions of early stop condition
Browse files- pysr/sr.py +8 -3
pysr/sr.py
CHANGED
@@ -312,8 +312,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
312 |
annealing : bool, default=True
|
313 |
Whether to use annealing. You should (and it is default).
|
314 |
|
315 |
-
early_stop_condition : float, default=None
|
316 |
-
Stop the search early if this loss is reached.
|
|
|
|
|
|
|
317 |
|
318 |
ncyclesperiteration : int, default=550
|
319 |
Number of total mutations to run, per 10 samples of the
|
@@ -971,6 +974,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
971 |
|
972 |
# 'Mutable' parameter validation
|
973 |
buffer_available = "buffer" in sys.stdout.__dir__()
|
|
|
974 |
modifiable_params = {
|
975 |
"binary_operators": "+ * - /".split(" "),
|
976 |
"unary_operators": [],
|
@@ -1308,6 +1312,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1308 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
1309 |
|
1310 |
custom_loss = Main.eval(self.loss)
|
|
|
1311 |
|
1312 |
mutationWeights = [
|
1313 |
float(self.weight_mutate_constant),
|
@@ -1369,7 +1374,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1369 |
crossoverProbability=self.crossover_probability,
|
1370 |
skip_mutation_failures=self.skip_mutation_failures,
|
1371 |
max_evals=self.max_evals,
|
1372 |
-
earlyStopCondition=
|
1373 |
seed=seed,
|
1374 |
deterministic=self.deterministic,
|
1375 |
)
|
|
|
312 |
annealing : bool, default=True
|
313 |
Whether to use annealing. You should (and it is default).
|
314 |
|
315 |
+
early_stop_condition : { float | str }, default=None
|
316 |
+
Stop the search early if this loss is reached. You may also
|
317 |
+
pass a string containing a Julia function which
|
318 |
+
takes a loss and complexity as input, for example:
|
319 |
+
`"f(loss, complexity) = (loss < 0.1) && (complexity < 10)"`.
|
320 |
|
321 |
ncyclesperiteration : int, default=550
|
322 |
Number of total mutations to run, per 10 samples of the
|
|
|
974 |
|
975 |
# 'Mutable' parameter validation
|
976 |
buffer_available = "buffer" in sys.stdout.__dir__()
|
977 |
+
# Params and their default values, if None is given:
|
978 |
modifiable_params = {
|
979 |
"binary_operators": "+ * - /".split(" "),
|
980 |
"unary_operators": [],
|
|
|
1312 |
complexity_of_operators = Main.eval(complexity_of_operators_str)
|
1313 |
|
1314 |
custom_loss = Main.eval(self.loss)
|
1315 |
+
early_stop_condition = Main.eval(self.early_stop_condition)
|
1316 |
|
1317 |
mutationWeights = [
|
1318 |
float(self.weight_mutate_constant),
|
|
|
1374 |
crossoverProbability=self.crossover_probability,
|
1375 |
skip_mutation_failures=self.skip_mutation_failures,
|
1376 |
max_evals=self.max_evals,
|
1377 |
+
earlyStopCondition=early_stop_condition,
|
1378 |
seed=seed,
|
1379 |
deterministic=self.deterministic,
|
1380 |
)
|