MilesCranmer commited on
Commit
a15823e
1 Parent(s): c41cf33

Reduce precision of JAX tests

Browse files
Files changed (1) hide show
  1. test/test_jax.py +3 -3
test/test_jax.py CHANGED
@@ -49,7 +49,7 @@ class TestJAX(unittest.TestCase):
49
  np.testing.assert_almost_equal(
50
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
51
  np.square(np.cos(X.values[:, 1])), # Select feature 1
52
- decimal=4,
53
  )
54
 
55
  def test_pipeline(self):
@@ -110,5 +110,5 @@ class TestJAX(unittest.TestCase):
110
  np_output = np_prediction(X.values)
111
  jax_output = jax_prediction(X.values)
112
 
113
- np.testing.assert_almost_equal(y.values, np_output, decimal=4)
114
- np.testing.assert_almost_equal(y.values, jax_output, decimal=4)
 
49
  np.testing.assert_almost_equal(
50
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
51
  np.square(np.cos(X.values[:, 1])), # Select feature 1
52
+ decimal=3,
53
  )
54
 
55
  def test_pipeline(self):
 
110
  np_output = np_prediction(X.values)
111
  jax_output = jax_prediction(X.values)
112
 
113
+ np.testing.assert_almost_equal(y.values, np_output, decimal=3)
114
+ np.testing.assert_almost_equal(y.values, jax_output, decimal=3)