MilesCranmer commited on
Commit
962c25c
1 Parent(s): 4d5aec3

Add test for mod mapping in torch

Browse files
Files changed (1) hide show
  1. test/test_torch.py +17 -0
test/test_torch.py CHANGED
@@ -51,3 +51,20 @@ class TestTorch(unittest.TestCase):
51
  np.square(np.cos(X[:, 1])), # Selection 1st feature
52
  decimal=4,
53
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  np.square(np.cos(X[:, 1])), # Selection 1st feature
52
  decimal=4,
53
  )
54
+
55
+ def test_mod_mapping(self):
56
+ x, y, z = sympy.symbols("x y z")
57
+ expression = x ** 2 + sympy.atanh(sympy.Mod(y + 1, 2) - 1) * 3.2 * z
58
+
59
+ module = sympy2torch(expression, [x, y, z])
60
+
61
+ X = torch.rand(100, 3).float() * 10
62
+
63
+ true_out = (
64
+ X[:, 0] ** 2 + torch.atanh(torch.fmod(X[:, 1] + 1, 2) - 1) * 3.2 * X[:, 2]
65
+ )
66
+ torch_out = module(X)
67
+
68
+ np.testing.assert_array_almost_equal(
69
+ true_out.detach(), torch_out.detach(), decimal=4
70
+ )