MilesCranmer commited on
Commit
5b0cc10
1 Parent(s): 43e7deb

Reduce precision of torch tests

Browse files
Files changed (1) hide show
  1. test/test_torch.py +6 -6
test/test_torch.py CHANGED
@@ -65,7 +65,7 @@ class TestTorch(unittest.TestCase):
65
  np.testing.assert_almost_equal(
66
  tformat(torch.tensor(X.values)).detach().numpy(),
67
  np.square(np.cos(X.values[:, 1])), # Selection 1st feature
68
- decimal=4,
69
  )
70
 
71
  def test_pipeline(self):
@@ -99,7 +99,7 @@ class TestTorch(unittest.TestCase):
99
  np.testing.assert_almost_equal(
100
  tformat(torch.tensor(X)).detach().numpy(),
101
  np.square(np.cos(X[:, 1])), # 2nd feature
102
- decimal=4,
103
  )
104
 
105
  def test_mod_mapping(self):
@@ -116,7 +116,7 @@ class TestTorch(unittest.TestCase):
116
  torch_out = module(X)
117
 
118
  np.testing.assert_array_almost_equal(
119
- true_out.detach(), torch_out.detach(), decimal=4
120
  )
121
 
122
  def test_custom_operator(self):
@@ -156,7 +156,7 @@ class TestTorch(unittest.TestCase):
156
  np.testing.assert_almost_equal(
157
  tformat(torch.tensor(X)).detach().numpy(),
158
  np.sin(X[:, 1]),
159
- decimal=4,
160
  )
161
 
162
  def test_feature_selection_custom_operators(self):
@@ -186,5 +186,5 @@ class TestTorch(unittest.TestCase):
186
 
187
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
188
 
189
- np.testing.assert_almost_equal(y.values, np_output, decimal=4)
190
- np.testing.assert_almost_equal(y.values, torch_output, decimal=4)
 
65
  np.testing.assert_almost_equal(
66
  tformat(torch.tensor(X.values)).detach().numpy(),
67
  np.square(np.cos(X.values[:, 1])), # Selection 1st feature
68
+ decimal=3,
69
  )
70
 
71
  def test_pipeline(self):
 
99
  np.testing.assert_almost_equal(
100
  tformat(torch.tensor(X)).detach().numpy(),
101
  np.square(np.cos(X[:, 1])), # 2nd feature
102
+ decimal=3,
103
  )
104
 
105
  def test_mod_mapping(self):
 
116
  torch_out = module(X)
117
 
118
  np.testing.assert_array_almost_equal(
119
+ true_out.detach(), torch_out.detach(), decimal=3
120
  )
121
 
122
  def test_custom_operator(self):
 
156
  np.testing.assert_almost_equal(
157
  tformat(torch.tensor(X)).detach().numpy(),
158
  np.sin(X[:, 1]),
159
+ decimal=3,
160
  )
161
 
162
  def test_feature_selection_custom_operators(self):
 
186
 
187
  torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
188
 
189
+ np.testing.assert_almost_equal(y.values, np_output, decimal=3)
190
+ np.testing.assert_almost_equal(y.values, torch_output, decimal=3)