MilesCranmer commited on
Commit
5a621f9
·
unverified ·
1 Parent(s): f2bce1e

test: non-simplifying sympify

Browse files
Files changed (1) hide show
  1. pysr/test/test_torch.py +24 -1
pysr/test/test_torch.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
  import pandas as pd
5
  import sympy
6
 
 
7
  from pysr import PySRRegressor, sympy2torch
8
 
9
 
@@ -153,10 +154,32 @@ class TestTorch(unittest.TestCase):
153
  decimal=3,
154
  )
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def test_feature_selection_custom_operators(self):
157
  rstate = np.random.RandomState(0)
158
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
159
- cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
 
 
 
160
  y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
161
 
162
  model = PySRRegressor(
 
4
  import pandas as pd
5
  import sympy
6
 
7
+ import pysr
8
  from pysr import PySRRegressor, sympy2torch
9
 
10
 
 
154
  decimal=3,
155
  )
156
 
157
+ def test_avoid_simplification(self):
158
+ # SymPy should not simplify without permission
159
+ torch = self.torch
160
+ ex = pysr.export_sympy.pysr2sympy(
161
+ "square(exp(sign(0.44796443))) + 1.5 * x1",
162
+ # ^ Normally this would become exp1 and require
163
+ # its own mapping
164
+ feature_names_in=["x1"],
165
+ extra_sympy_mappings={"square": lambda x: x**2},
166
+ )
167
+ m = pysr.export_torch.sympy2torch(ex, ["x1"])
168
+ rng = np.random.RandomState(0)
169
+ X = rng.randn(10, 1)
170
+ np.testing.assert_almost_equal(
171
+ m(torch.tensor(X)).detach().numpy(),
172
+ np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
173
+ decimal=3,
174
+ )
175
+
176
  def test_feature_selection_custom_operators(self):
177
  rstate = np.random.RandomState(0)
178
  X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
179
+
180
+ def cos_approx(x):
181
+ return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
182
+
183
  y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
184
 
185
  model = PySRRegressor(