MilesCranmer commited on
Commit
c9cead8
1 Parent(s): 7cda629

Make torch custom operator test deterministic

Browse files
Files changed (1) hide show
  1. test/test_torch.py +10 -3
test/test_torch.py CHANGED
@@ -160,9 +160,10 @@ class TestTorch(unittest.TestCase):
160
  )
161
 
162
  def test_feature_selection_custom_operators(self):
163
- X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
 
164
  cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
165
- y = X["k15"] ** 2 + cos_approx(X["k20"])
166
 
167
  model = PySRRegressor(
168
  progress=False,
@@ -172,7 +173,12 @@ class TestTorch(unittest.TestCase):
172
  early_stop_condition=1e-5,
173
  extra_sympy_mappings={"cos_approx": cos_approx},
174
  extra_torch_mappings={"cos_approx": cos_approx},
 
 
 
 
175
  )
 
176
  model.fit(X.values, y.values)
177
  torch_module = model.pytorch()
178
 
@@ -180,4 +186,5 @@ class TestTorch(unittest.TestCase):
180
 
181
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
182
 
183
- np.testing.assert_almost_equal(np_output, torch_output, decimal=4)
 
 
160
  )
161
 
162
  def test_feature_selection_custom_operators(self):
163
+ rstate = np.random.RandomState(0)
164
+ X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
165
  cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
166
+ y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
167
 
168
  model = PySRRegressor(
169
  progress=False,
 
173
  early_stop_condition=1e-5,
174
  extra_sympy_mappings={"cos_approx": cos_approx},
175
  extra_torch_mappings={"cos_approx": cos_approx},
176
+ random_state=0,
177
+ deterministic=True,
178
+ procs=0,
179
+ multithreading=False,
180
  )
181
+ np.random.seed(0)
182
  model.fit(X.values, y.values)
183
  torch_module = model.pytorch()
184
 
 
186
 
187
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
188
 
189
+ np.testing.assert_almost_equal(y.values, np_output, decimal=4)
190
+ np.testing.assert_almost_equal(y.values, torch_output, decimal=4)