MilesCranmer commited on
Commit
cb6939e
1 Parent(s): 41e5fd5

Add missing symbol to test

Browse files
Files changed (1) hide show
  1. test/test_jax.py +1 -1
test/test_jax.py CHANGED
@@ -11,5 +11,5 @@ cosx = 1.0 * sympy.cos(x) + y
11
  key = random.PRNGKey(0)
12
  X = random.normal(key, (1000, 2))
13
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
14
- f, params = sympy2jax(cosx, [x])
15
  assert jnp.all(jnp.isclose(f(X, params), true)).item()
 
11
  key = random.PRNGKey(0)
12
  X = random.normal(key, (1000, 2))
13
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
14
+ f, params = sympy2jax(cosx, [x, y, z])
15
  assert jnp.all(jnp.isclose(f(X, params), true)).item()