MilesCranmer commited on
Commit
7cda629
1 Parent(s): fad18e8

Add unit tests for custom torch/jax operations

Browse files
Files changed (2) hide show
  1. test/test_jax.py +9 -3
  2. test/test_torch.py +7 -3
test/test_jax.py CHANGED
@@ -79,15 +79,21 @@ class TestJAX(unittest.TestCase):
79
  decimal=4,
80
  )
81
 
82
- def test_feature_selection(self):
83
  X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
84
- y = X["k15"] ** 2 + np.cos(X["k20"])
 
85
 
86
  model = PySRRegressor(
87
  progress=False,
88
- unary_operators=["cos"],
89
  select_k_features=3,
 
90
  early_stop_condition=1e-5,
 
 
 
 
91
  )
92
  model.fit(X.values, y.values)
93
  f, parameters = model.jax().values()
 
79
  decimal=4,
80
  )
81
 
82
+ def test_feature_selection_custom_operators(self):
83
  X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
84
+ cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
85
+ y = X["k15"] ** 2 + cos_approx(X["k20"])
86
 
87
  model = PySRRegressor(
88
  progress=False,
89
+ unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
90
  select_k_features=3,
91
+ maxsize=10,
92
  early_stop_condition=1e-5,
93
+ extra_sympy_mappings={"cos_approx": cos_approx},
94
+ extra_jax_mappings={
95
+ "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
96
+ },
97
  )
98
  model.fit(X.values, y.values)
99
  f, parameters = model.jax().values()
test/test_torch.py CHANGED
@@ -159,15 +159,19 @@ class TestTorch(unittest.TestCase):
159
  decimal=4,
160
  )
161
 
162
- def test_feature_selection(self):
163
  X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
164
- y = X["k15"] ** 2 + np.cos(X["k20"])
 
165
 
166
  model = PySRRegressor(
167
  progress=False,
168
- unary_operators=["cos"],
169
  select_k_features=3,
 
170
  early_stop_condition=1e-5,
 
 
171
  )
172
  model.fit(X.values, y.values)
173
  torch_module = model.pytorch()
 
159
  decimal=4,
160
  )
161
 
162
+ def test_feature_selection_custom_operators(self):
163
  X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
164
+ cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
165
+ y = X["k15"] ** 2 + cos_approx(X["k20"])
166
 
167
  model = PySRRegressor(
168
  progress=False,
169
+ unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
170
  select_k_features=3,
171
+ maxsize=10,
172
  early_stop_condition=1e-5,
173
+ extra_sympy_mappings={"cos_approx": cos_approx},
174
+ extra_torch_mappings={"cos_approx": cos_approx},
175
  )
176
  model.fit(X.values, y.values)
177
  torch_module = model.pytorch()