Spaces:
Running
Running
tttc3
committed on
Commit
•
ad1c492
1
Parent(s):
3182a3b
Addressed some DeepSource issues
Browse files
- pysr/sr.py +44 -30
pysr/sr.py
CHANGED
@@ -2,7 +2,6 @@ import os
|
|
2 |
import sys
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
5 |
-
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
6 |
import sympy
|
7 |
from sympy import sympify
|
8 |
import re
|
@@ -13,6 +12,7 @@ from datetime import datetime
|
|
13 |
import warnings
|
14 |
from multiprocessing import cpu_count
|
15 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
|
|
16 |
from sklearn.utils.validation import (
|
17 |
_check_feature_names_in,
|
18 |
check_is_fitted,
|
@@ -76,7 +76,8 @@ sympy_mappings = {
|
|
76 |
|
77 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
78 |
warnings.warn(
|
79 |
-
"Calling `pysr` is deprecated.
|
|
|
80 |
FutureWarning,
|
81 |
)
|
82 |
model = PySRRegressor(**kwargs)
|
@@ -95,7 +96,8 @@ def _process_constraints(binary_operators, unary_operators, constraints):
|
|
95 |
if op in ["plus", "sub", "+", "-"]:
|
96 |
if constraints[op][0] != constraints[op][1]:
|
97 |
raise NotImplementedError(
|
98 |
-
"You need equal constraints on both sides for - and +,
|
|
|
99 |
)
|
100 |
elif op in ["mult", "*"]:
|
101 |
# Make sure the complex expression is in the left side.
|
@@ -128,7 +130,8 @@ def _maybe_create_inline_operators(binary_operators, unary_operators):
|
|
128 |
if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
|
129 |
raise ValueError(
|
130 |
f"Invalid function name {function_name}. "
|
131 |
-
"Only alphanumeric characters, numbers,
|
|
|
132 |
)
|
133 |
op_list[i] = function_name
|
134 |
return binary_operators, unary_operators
|
@@ -154,25 +157,32 @@ def _check_assertions(
|
|
154 |
|
155 |
def best(*args, **kwargs): # pragma: no cover
|
156 |
raise NotImplementedError(
|
157 |
-
"`best` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
158 |
)
|
159 |
|
160 |
|
161 |
def best_row(*args, **kwargs): # pragma: no cover
|
162 |
raise NotImplementedError(
|
163 |
-
"`best_row` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
164 |
)
|
165 |
|
166 |
|
167 |
def best_tex(*args, **kwargs): # pragma: no cover
|
168 |
raise NotImplementedError(
|
169 |
-
"`best_tex` has been deprecated. Please use the `PySRRegressor` interface.
|
|
|
|
|
170 |
)
|
171 |
|
172 |
|
173 |
def best_callable(*args, **kwargs): # pragma: no cover
|
174 |
raise NotImplementedError(
|
175 |
-
"`best_callable` has been deprecated. Please use the `PySRRegressor`
|
|
|
176 |
)
|
177 |
|
178 |
|
@@ -775,7 +785,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
775 |
setattr(self, updated_kwarg_name, v)
|
776 |
warnings.warn(
|
777 |
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
778 |
-
"
|
779 |
FutureWarning,
|
780 |
)
|
781 |
# Handle kwargs that have been moved to the fit method
|
@@ -787,7 +797,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
787 |
)
|
788 |
else:
|
789 |
raise TypeError(
|
790 |
-
f"{k} is not a valid keyword argument for PySRRegressor"
|
791 |
)
|
792 |
|
793 |
def __repr__(self):
|
@@ -964,7 +974,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
964 |
values. For example, default parameters are set here
|
965 |
when a parameter is left set to `None`.
|
966 |
"""
|
967 |
-
|
968 |
# Immutable parameter validation
|
969 |
# Ensure instance parameters are allowable values:
|
970 |
if self.tournament_selection_n > self.population_size:
|
@@ -974,27 +983,29 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
974 |
|
975 |
if self.maxsize > 40:
|
976 |
warnings.warn(
|
977 |
-
"Note: Using a large maxsize for the equation search will be
|
|
|
|
|
978 |
)
|
979 |
elif self.maxsize < 7:
|
980 |
raise ValueError("PySR requires a maxsize of at least 7")
|
981 |
|
982 |
-
if self.deterministic
|
983 |
-
|
984 |
-
|
985 |
-
|
986 |
-
|
987 |
-
|
988 |
-
|
989 |
-
|
990 |
-
|
991 |
-
)
|
992 |
|
993 |
-
if self.random_state
|
|
|
|
|
994 |
warnings.warn(
|
995 |
"Note: Setting `random_state` without also setting `deterministic` "
|
996 |
-
"to True and `procs` to 0 "
|
997 |
-
"will result in non-deterministic searches. "
|
998 |
)
|
999 |
|
1000 |
# NotImplementedError - Values that could be supported at a later time
|
@@ -1035,7 +1046,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1035 |
parameter_value = 1
|
1036 |
elif parameter == "progress" and not buffer_available:
|
1037 |
warnings.warn(
|
1038 |
-
"Note: it looks like you are running in Jupyter.
|
|
|
1039 |
)
|
1040 |
parameter_value = False
|
1041 |
packed_modified_params[parameter] = parameter_value
|
@@ -1087,7 +1099,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1087 |
Validated list of variable names for each feature in `X`.
|
1088 |
|
1089 |
"""
|
1090 |
-
|
1091 |
if isinstance(X, pd.DataFrame):
|
1092 |
if variable_names:
|
1093 |
variable_names = None
|
@@ -1803,7 +1814,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1803 |
)
|
1804 |
except FileNotFoundError:
|
1805 |
raise RuntimeError(
|
1806 |
-
"Couldn't find equation file! The equation search likely exited
|
|
|
1807 |
)
|
1808 |
|
1809 |
# It is expected extra_jax/torch_mappings will be updated after fit.
|
@@ -1814,7 +1826,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1814 |
for value in extra_jax_mappings.values():
|
1815 |
if not isinstance(value, str):
|
1816 |
raise ValueError(
|
1817 |
-
"extra_jax_mappings must have keys that are strings!
|
|
|
1818 |
)
|
1819 |
else:
|
1820 |
extra_jax_mappings = {}
|
@@ -1822,7 +1835,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1822 |
for value in extra_jax_mappings.values():
|
1823 |
if not callable(value):
|
1824 |
raise ValueError(
|
1825 |
-
"extra_torch_mappings must be callable functions!
|
|
|
1826 |
)
|
1827 |
else:
|
1828 |
extra_torch_mappings = {}
|
|
|
2 |
import sys
|
3 |
import numpy as np
|
4 |
import pandas as pd
|
|
|
5 |
import sympy
|
6 |
from sympy import sympify
|
7 |
import re
|
|
|
12 |
import warnings
|
13 |
from multiprocessing import cpu_count
|
14 |
from sklearn.base import BaseEstimator, RegressorMixin, MultiOutputMixin
|
15 |
+
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
16 |
from sklearn.utils.validation import (
|
17 |
_check_feature_names_in,
|
18 |
check_is_fitted,
|
|
|
76 |
|
77 |
def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
78 |
warnings.warn(
|
79 |
+
"Calling `pysr` is deprecated. "
|
80 |
+
"Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.",
|
81 |
FutureWarning,
|
82 |
)
|
83 |
model = PySRRegressor(**kwargs)
|
|
|
96 |
if op in ["plus", "sub", "+", "-"]:
|
97 |
if constraints[op][0] != constraints[op][1]:
|
98 |
raise NotImplementedError(
|
99 |
+
"You need equal constraints on both sides for - and +, "
|
100 |
+
"due to simplification strategies."
|
101 |
)
|
102 |
elif op in ["mult", "*"]:
|
103 |
# Make sure the complex expression is in the left side.
|
|
|
130 |
if not re.match(r"^[a-zA-Z0-9_]+$", function_name):
|
131 |
raise ValueError(
|
132 |
f"Invalid function name {function_name}. "
|
133 |
+
"Only alphanumeric characters, numbers, "
|
134 |
+
"and underscores are allowed."
|
135 |
)
|
136 |
op_list[i] = function_name
|
137 |
return binary_operators, unary_operators
|
|
|
157 |
|
158 |
def best(*args, **kwargs): # pragma: no cover
|
159 |
raise NotImplementedError(
|
160 |
+
"`best` has been deprecated. Please use the `PySRRegressor` interface. "
|
161 |
+
"After fitting, you can return `.sympy()` to get the sympy representation "
|
162 |
+
"of the best equation."
|
163 |
)
|
164 |
|
165 |
|
166 |
def best_row(*args, **kwargs): # pragma: no cover
|
167 |
raise NotImplementedError(
|
168 |
+
"`best_row` has been deprecated. Please use the `PySRRegressor` interface. "
|
169 |
+
"After fitting, you can run `print(model)` to view the best equation, or "
|
170 |
+
"`model.get_best()` to return the best equation's row in `model.equations`."
|
171 |
)
|
172 |
|
173 |
|
174 |
def best_tex(*args, **kwargs): # pragma: no cover
|
175 |
raise NotImplementedError(
|
176 |
+
"`best_tex` has been deprecated. Please use the `PySRRegressor` interface. "
|
177 |
+
"After fitting, you can return `.latex()` to get the sympy representation "
|
178 |
+
"of the best equation."
|
179 |
)
|
180 |
|
181 |
|
182 |
def best_callable(*args, **kwargs): # pragma: no cover
|
183 |
raise NotImplementedError(
|
184 |
+
"`best_callable` has been deprecated. Please use the `PySRRegressor` "
|
185 |
+
"interface. After fitting, you can use `.predict(X)` to use the best callable."
|
186 |
)
|
187 |
|
188 |
|
|
|
785 |
setattr(self, updated_kwarg_name, v)
|
786 |
warnings.warn(
|
787 |
f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
|
788 |
+
"Please use that instead.",
|
789 |
FutureWarning,
|
790 |
)
|
791 |
# Handle kwargs that have been moved to the fit method
|
|
|
797 |
)
|
798 |
else:
|
799 |
raise TypeError(
|
800 |
+
f"{k} is not a valid keyword argument for PySRRegressor."
|
801 |
)
|
802 |
|
803 |
def __repr__(self):
|
|
|
974 |
values. For example, default parameters are set here
|
975 |
when a parameter is left set to `None`.
|
976 |
"""
|
|
|
977 |
# Immutable parameter validation
|
978 |
# Ensure instance parameters are allowable values:
|
979 |
if self.tournament_selection_n > self.population_size:
|
|
|
983 |
|
984 |
if self.maxsize > 40:
|
985 |
warnings.warn(
|
986 |
+
"Note: Using a large maxsize for the equation search will be "
|
987 |
+
"exponentially slower and use significant memory. You should consider "
|
988 |
+
"turning `use_frequency` to False, and perhaps use `warmup_maxsize_by`."
|
989 |
)
|
990 |
elif self.maxsize < 7:
|
991 |
raise ValueError("PySR requires a maxsize of at least 7")
|
992 |
|
993 |
+
if self.deterministic and not (
|
994 |
+
self.multithreading in [False, None]
|
995 |
+
and self.procs == 0
|
996 |
+
and self.random_state is not None
|
997 |
+
):
|
998 |
+
raise ValueError(
|
999 |
+
"To ensure deterministic searches, you must set `random_state` to a seed, "
|
1000 |
+
"`procs` to `0`, and `multithreading` to `False` or `None`."
|
1001 |
+
)
|
|
|
1002 |
|
1003 |
+
if self.random_state is not None and (
|
1004 |
+
not self.deterministic or self.procs != 0
|
1005 |
+
):
|
1006 |
warnings.warn(
|
1007 |
"Note: Setting `random_state` without also setting `deterministic` "
|
1008 |
+
"to True and `procs` to 0 will result in non-deterministic searches. "
|
|
|
1009 |
)
|
1010 |
|
1011 |
# NotImplementedError - Values that could be supported at a later time
|
|
|
1046 |
parameter_value = 1
|
1047 |
elif parameter == "progress" and not buffer_available:
|
1048 |
warnings.warn(
|
1049 |
+
"Note: it looks like you are running in Jupyter. "
|
1050 |
+
"The progress bar will be turned off."
|
1051 |
)
|
1052 |
parameter_value = False
|
1053 |
packed_modified_params[parameter] = parameter_value
|
|
|
1099 |
Validated list of variable names for each feature in `X`.
|
1100 |
|
1101 |
"""
|
|
|
1102 |
if isinstance(X, pd.DataFrame):
|
1103 |
if variable_names:
|
1104 |
variable_names = None
|
|
|
1814 |
)
|
1815 |
except FileNotFoundError:
|
1816 |
raise RuntimeError(
|
1817 |
+
"Couldn't find equation file! The equation search likely exited "
|
1818 |
+
"before a single iteration completed."
|
1819 |
)
|
1820 |
|
1821 |
# It is expected extra_jax/torch_mappings will be updated after fit.
|
|
|
1826 |
for value in extra_jax_mappings.values():
|
1827 |
if not isinstance(value, str):
|
1828 |
raise ValueError(
|
1829 |
+
"extra_jax_mappings must have keys that are strings! "
|
1830 |
+
"e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
1831 |
)
|
1832 |
else:
|
1833 |
extra_jax_mappings = {}
|
|
|
1835 |
for value in extra_jax_mappings.values():
|
1836 |
if not callable(value):
|
1837 |
raise ValueError(
|
1838 |
+
"extra_torch_mappings must be callable functions! "
|
1839 |
+
"e.g., {sympy.sqrt: torch.sqrt}."
|
1840 |
)
|
1841 |
else:
|
1842 |
extra_torch_mappings = {}
|