Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
a15823e
1
Parent(s):
c41cf33
Reduce precision of JAX tests
Browse files- test/test_jax.py +3 -3
test/test_jax.py
CHANGED
@@ -49,7 +49,7 @@ class TestJAX(unittest.TestCase):
|
|
49 |
np.testing.assert_almost_equal(
|
50 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
51 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
52 |
-
decimal=
|
53 |
)
|
54 |
|
55 |
def test_pipeline(self):
|
@@ -110,5 +110,5 @@ class TestJAX(unittest.TestCase):
|
|
110 |
np_output = np_prediction(X.values)
|
111 |
jax_output = jax_prediction(X.values)
|
112 |
|
113 |
-
np.testing.assert_almost_equal(y.values, np_output, decimal=
|
114 |
-
np.testing.assert_almost_equal(y.values, jax_output, decimal=
|
|
|
49 |
np.testing.assert_almost_equal(
|
50 |
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
51 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
52 |
+
decimal=3,
|
53 |
)
|
54 |
|
55 |
def test_pipeline(self):
|
|
|
110 |
np_output = np_prediction(X.values)
|
111 |
jax_output = jax_prediction(X.values)
|
112 |
|
113 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=3)
|
114 |
+
np.testing.assert_almost_equal(y.values, jax_output, decimal=3)
|