MilesCranmer commited on
Commit
beaf20b
1 Parent(s): f119733

Make JAX custom operator test deterministic

Browse files
Files changed (1) hide show
  1. test/test_jax.py +10 -3
test/test_jax.py CHANGED
@@ -80,9 +80,10 @@ class TestJAX(unittest.TestCase):
80
  )
81
 
82
  def test_feature_selection_custom_operators(self):
83
- X = pd.DataFrame({f"k{i}": np.random.randn(1000) for i in range(10, 21)})
 
84
  cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
85
- y = X["k15"] ** 2 + cos_approx(X["k20"])
86
 
87
  model = PySRRegressor(
88
  progress=False,
@@ -94,7 +95,12 @@ class TestJAX(unittest.TestCase):
94
  extra_jax_mappings={
95
  "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
96
  },
 
 
 
 
97
  )
 
98
  model.fit(X.values, y.values)
99
  f, parameters = model.jax().values()
100
 
@@ -104,4 +110,5 @@ class TestJAX(unittest.TestCase):
104
  np_output = np_prediction(X.values)
105
  jax_output = jax_prediction(X.values)
106
 
107
- np.testing.assert_almost_equal(np_output, jax_output, decimal=4)
 
 
80
  )
81
 
82
  def test_feature_selection_custom_operators(self):
83
+ rstate = np.random.RandomState(0)
84
+ X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
85
  cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
86
+ y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])
87
 
88
  model = PySRRegressor(
89
  progress=False,
 
95
  extra_jax_mappings={
96
  "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
97
  },
98
+ random_state=0,
99
+ deterministic=True,
100
+ procs=0,
101
+ multithreading=False,
102
  )
103
+ np.random.seed(0)
104
  model.fit(X.values, y.values)
105
  f, parameters = model.jax().values()
106
 
 
110
  np_output = np_prediction(X.values)
111
  jax_output = jax_prediction(X.values)
112
 
113
+ np.testing.assert_almost_equal(y.values, np_output, decimal=4)
114
+ np.testing.assert_almost_equal(y.values, jax_output, decimal=4)