MilesCranmer commited on
Commit
ce5b119
1 Parent(s): 775c667

Add test for feature selection in JAX output

Browse files
Files changed (1) hide show
  1. test/test_jax.py +19 -0
test/test_jax.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
7
  import sympy
 
8
 
9
 
10
  class TestJAX(unittest.TestCase):
@@ -79,3 +80,21 @@ class TestJAX(unittest.TestCase):
79
  np.square(np.cos(X[:, 1])), # Select feature 1
80
  decimal=4,
81
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from jax import numpy as jnp
6
  from jax import random
7
  import sympy
8
+ from functools import partial
9
 
10
 
11
  class TestJAX(unittest.TestCase):
 
80
  np.square(np.cos(X[:, 1])), # Select feature 1
81
  decimal=4,
82
  )
83
+
84
+ def test_feature_selection(self):
85
+ X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
86
+ y = X["k15"] ** 2 + np.cos(X["k20"])
87
+
88
+ model = PySRRegressor(
89
+ unary_operators=["cos"], select_k_features=3, early_stop_condition=1e-5
90
+ )
91
+ model.fit(X.values, y.values)
92
+ f, parameters = model.jax().values()
93
+
94
+ np_prediction = model.predict
95
+ jax_prediction = partial(f, parameters=parameters)
96
+
97
+ np_output = np_prediction(X.values)
98
+ jax_output = jax_prediction(X.values)
99
+
100
+ np.testing.assert_almost_equal(np_output, jax_output, decimal=4)