MilesCranmer commited on
Commit
12e6d5e
1 Parent(s): 09a7186

Move denoising functionality to separate file

Browse files
Files changed (2) hide show
  1. pysr/denoising.py +35 -0
  2. pysr/sr.py +4 -27
pysr/denoising.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions for denoising data during preprocessing."""
2
+ import numpy as np
3
+
4
+
5
+ def denoise(X, y, Xresampled=None, random_state=None):
6
+ """Denoise the dataset using a Gaussian process."""
7
+ from sklearn.gaussian_process import GaussianProcessRegressor
8
+ from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
9
+
10
+ gp_kernel = RBF(np.ones(X.shape[1])) + WhiteKernel(1e-1) + ConstantKernel()
11
+ gpr = GaussianProcessRegressor(
12
+ kernel=gp_kernel, n_restarts_optimizer=50, random_state=random_state
13
+ )
14
+ gpr.fit(X, y)
15
+
16
+ if Xresampled is not None:
17
+ return Xresampled, gpr.predict(Xresampled)
18
+
19
+ return X, gpr.predict(X)
20
+
21
+
22
+ def multi_denoise(X, y, Xresampled=None, random_state=None):
23
+ """Perform `denoise` along each column of `y` independently."""
24
+ y = np.stack(
25
+ [
26
+ denoise(X, y[:, i], Xresampled=Xresampled, random_state=random_state)[1]
27
+ for i in range(y.shape[1])
28
+ ],
29
+ axis=1,
30
+ )
31
+
32
+ if Xresampled is not None:
33
+ return Xresampled, y
34
+
35
+ return X, y
pysr/sr.py CHANGED
@@ -18,6 +18,7 @@ from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
18
  from sklearn.utils import check_array, check_consistent_length, check_random_state
19
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
20
 
 
21
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
22
  from .export_jax import sympy2jax
23
  from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
@@ -1506,19 +1507,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1506
  # Denoising transformation
1507
  if self.denoise:
1508
  if self.nout_ > 1:
1509
- y = np.stack(
1510
- [
1511
- _denoise(
1512
- X, y[:, i], Xresampled=Xresampled, random_state=random_state
1513
- )[1]
1514
- for i in range(self.nout_)
1515
- ],
1516
- axis=1,
1517
  )
1518
- if Xresampled is not None:
1519
- X = Xresampled
1520
  else:
1521
- X, y = _denoise(X, y, Xresampled=Xresampled, random_state=random_state)
1522
 
1523
  return X, y, variable_names, X_units, y_units
1524
 
@@ -2394,22 +2387,6 @@ def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
2394
  return chosen_idx
2395
 
2396
 
2397
- def _denoise(X, y, Xresampled=None, random_state=None):
2398
- """Denoise the dataset using a Gaussian process."""
2399
- from sklearn.gaussian_process import GaussianProcessRegressor
2400
- from sklearn.gaussian_process.kernels import RBF, ConstantKernel, WhiteKernel
2401
-
2402
- gp_kernel = RBF(np.ones(X.shape[1])) + WhiteKernel(1e-1) + ConstantKernel()
2403
- gpr = GaussianProcessRegressor(
2404
- kernel=gp_kernel, n_restarts_optimizer=50, random_state=random_state
2405
- )
2406
- gpr.fit(X, y)
2407
- if Xresampled is not None:
2408
- return Xresampled, gpr.predict(Xresampled)
2409
-
2410
- return X, gpr.predict(X)
2411
-
2412
-
2413
  # Function has not been removed only due to usage in module tests
2414
  def _handle_feature_selection(X, select_k_features, y, variable_names):
2415
  if select_k_features is not None:
 
18
  from sklearn.utils import check_array, check_consistent_length, check_random_state
19
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
20
 
21
+ from .denoising import denoise, multi_denoise
22
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
23
  from .export_jax import sympy2jax
24
  from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
 
1507
  # Denoising transformation
1508
  if self.denoise:
1509
  if self.nout_ > 1:
1510
+ X, y = multi_denoise(
1511
+ X, y, Xresampled=Xresampled, random_state=random_state
 
 
 
 
 
 
1512
  )
 
 
1513
  else:
1514
+ X, y = denoise(X, y, Xresampled=Xresampled, random_state=random_state)
1515
 
1516
  return X, y, variable_names, X_units, y_units
1517
 
 
2387
  return chosen_idx
2388
 
2389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2390
  # Function has not been removed only due to usage in module tests
2391
  def _handle_feature_selection(X, select_k_features, y, variable_names):
2392
  if select_k_features is not None: