Spaces:
Running
Running
MilesCranmer
committed on
Commit
•
12e6d5e
1
Parent(s):
09a7186
Move denoising functionality to separate file
Browse files
- pysr/denoising.py +35 -0
- pysr/sr.py +4 -27
pysr/denoising.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Functions for denoising data during preprocessing."""
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
def denoise(X, y, Xresampled=None, random_state=None):
    """Smooth `y` by fitting a Gaussian process regressor to `(X, y)`.

    The kernel is the sum of an RBF term (one length scale per column of
    `X`, each initialised to 1), a white-noise term (initial level 0.1)
    and a constant term. The regressor is fitted with 50 optimizer
    restarts and the given `random_state`.

    Predictions are evaluated at `Xresampled` when it is provided,
    otherwise at the training inputs `X`.

    Returns
    -------
    tuple
        `(points, predictions)`, where `points` is `Xresampled` if it was
        given and `X` otherwise.
    """
    # Imports are function-local, so scikit-learn is only required when
    # this function is actually called.
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel

    n_features = X.shape[1]
    length_scales = np.ones(n_features)
    kernel = RBF(length_scales) + WhiteKernel(1e-1) + ConstantKernel()

    model = GaussianProcessRegressor(
        kernel=kernel,
        n_restarts_optimizer=50,
        random_state=random_state,
    )
    model.fit(X, y)

    eval_points = X if Xresampled is None else Xresampled
    return eval_points, model.predict(eval_points)
|
20 |
+
|
21 |
+
|
22 |
+
def multi_denoise(X, y, Xresampled=None, random_state=None):
    """Run `denoise` separately on every column of `y`.

    Each column `y[:, i]` is passed to `denoise` together with the same
    `X`, `Xresampled` and `random_state`; only the predicted values (the
    second element of its result) are kept. The per-column predictions
    are stacked back together along axis 1.

    Returns
    -------
    tuple
        `(points, y_denoised)`, where `points` is `Xresampled` if it was
        given and `X` otherwise.
    """
    denoised_columns = []
    for col in range(y.shape[1]):
        column_result = denoise(
            X, y[:, col], Xresampled=Xresampled, random_state=random_state
        )
        # Index 1 holds the predictions; index 0 is the evaluation points.
        denoised_columns.append(column_result[1])

    y = np.stack(denoised_columns, axis=1)

    eval_points = X if Xresampled is None else Xresampled
    return eval_points, y
|
pysr/sr.py
CHANGED
@@ -18,6 +18,7 @@ from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
|
|
18 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
19 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
20 |
|
|
|
21 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
22 |
from .export_jax import sympy2jax
|
23 |
from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
|
@@ -1506,19 +1507,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1506 |
# Denoising transformation
|
1507 |
if self.denoise:
|
1508 |
if self.nout_ > 1:
|
1509 |
-
y = np.stack(
|
1510 |
-
[
|
1511 |
-
_denoise(
|
1512 |
-
X, y[:, i], Xresampled=Xresampled, random_state=random_state
|
1513 |
-
)[1]
|
1514 |
-
for i in range(self.nout_)
|
1515 |
-
],
|
1516 |
-
axis=1,
|
1517 |
)
|
1518 |
-
if Xresampled is not None:
|
1519 |
-
X = Xresampled
|
1520 |
else:
|
1521 |
-
X, y = _denoise(X, y, Xresampled=Xresampled, random_state=random_state)
|
1522 |
|
1523 |
return X, y, variable_names, X_units, y_units
|
1524 |
|
@@ -2394,22 +2387,6 @@ def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
|
|
2394 |
return chosen_idx
|
2395 |
|
2396 |
|
2397 |
-
def _denoise(X, y, Xresampled=None, random_state=None):
|
2398 |
-
"""Denoise the dataset using a Gaussian process."""
|
2399 |
-
from sklearn.gaussian_process import GaussianProcessRegressor
|
2400 |
-
from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
|
2401 |
-
|
2402 |
-
gp_kernel = RBF(np.ones(X.shape[1])) + WhiteKernel(1e-1) + ConstantKernel()
|
2403 |
-
gpr = GaussianProcessRegressor(
|
2404 |
-
kernel=gp_kernel, n_restarts_optimizer=50, random_state=random_state
|
2405 |
-
)
|
2406 |
-
gpr.fit(X, y)
|
2407 |
-
if Xresampled is not None:
|
2408 |
-
return Xresampled, gpr.predict(Xresampled)
|
2409 |
-
|
2410 |
-
return X, gpr.predict(X)
|
2411 |
-
|
2412 |
-
|
2413 |
# Function has not been removed only due to usage in module tests
|
2414 |
def _handle_feature_selection(X, select_k_features, y, variable_names):
|
2415 |
if select_k_features is not None:
|
|
|
18 |
from sklearn.utils import check_array, check_consistent_length, check_random_state
|
19 |
from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
|
20 |
|
21 |
+
from .denoising import denoise, multi_denoise
|
22 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
23 |
from .export_jax import sympy2jax
|
24 |
from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
|
|
|
1507 |
# Denoising transformation
|
1508 |
if self.denoise:
|
1509 |
if self.nout_ > 1:
|
1510 |
+
X, y = multi_denoise(
|
1511 |
+
X, y, Xresampled=Xresampled, random_state=random_state
|
|
|
|
|
|
|
|
|
|
|
|
|
1512 |
)
|
|
|
|
|
1513 |
else:
|
1514 |
+
X, y = denoise(X, y, Xresampled=Xresampled, random_state=random_state)
|
1515 |
|
1516 |
return X, y, variable_names, X_units, y_units
|
1517 |
|
|
|
2387 |
return chosen_idx
|
2388 |
|
2389 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2390 |
# Function has not been removed only due to usage in module tests
|
2391 |
def _handle_feature_selection(X, select_k_features, y, variable_names):
|
2392 |
if select_k_features is not None:
|