MilesCranmer commited on
Commit
03f8fff
·
1 Parent(s): 86dd9ce

Fix torch test to use numpy for assert

Browse files
Files changed (1) hide show
  1. test/test_torch.py +1 -1
test/test_torch.py CHANGED
@@ -9,4 +9,4 @@ cosx = 1.0 * sympy.cos(x) + y
9
  X = torch.randn((1000, 2))
10
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
11
  torch_module = sympy2torch(cosx, [x, y, z])
12
- assert jnp.all(jnp.isclose(torch_module(X), true)).item()
 
9
  X = torch.randn((1000, 2))
10
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
11
  torch_module = sympy2torch(cosx, [x, y, z])
12
+ assert np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))