Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
5b0cc10
1
Parent(s):
43e7deb
Reduce precision of torch tests
Browse files- test/test_torch.py +6 -6
test/test_torch.py
CHANGED
@@ -65,7 +65,7 @@ class TestTorch(unittest.TestCase):
|
|
65 |
np.testing.assert_almost_equal(
|
66 |
tformat(torch.tensor(X.values)).detach().numpy(),
|
67 |
np.square(np.cos(X.values[:, 1])), # Selection 1st feature
|
68 |
-
decimal=
|
69 |
)
|
70 |
|
71 |
def test_pipeline(self):
|
@@ -99,7 +99,7 @@ class TestTorch(unittest.TestCase):
|
|
99 |
np.testing.assert_almost_equal(
|
100 |
tformat(torch.tensor(X)).detach().numpy(),
|
101 |
np.square(np.cos(X[:, 1])), # 2nd feature
|
102 |
-
decimal=
|
103 |
)
|
104 |
|
105 |
def test_mod_mapping(self):
|
@@ -116,7 +116,7 @@ class TestTorch(unittest.TestCase):
|
|
116 |
torch_out = module(X)
|
117 |
|
118 |
np.testing.assert_array_almost_equal(
|
119 |
-
true_out.detach(), torch_out.detach(), decimal=
|
120 |
)
|
121 |
|
122 |
def test_custom_operator(self):
|
@@ -156,7 +156,7 @@ class TestTorch(unittest.TestCase):
|
|
156 |
np.testing.assert_almost_equal(
|
157 |
tformat(torch.tensor(X)).detach().numpy(),
|
158 |
np.sin(X[:, 1]),
|
159 |
-
decimal=
|
160 |
)
|
161 |
|
162 |
def test_feature_selection_custom_operators(self):
|
@@ -186,5 +186,5 @@ class TestTorch(unittest.TestCase):
|
|
186 |
|
187 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
188 |
|
189 |
-
np.testing.assert_almost_equal(y.values, np_output, decimal=
|
190 |
-
np.testing.assert_almost_equal(y.values, torch_output, decimal=
|
|
|
65 |
np.testing.assert_almost_equal(
|
66 |
tformat(torch.tensor(X.values)).detach().numpy(),
|
67 |
np.square(np.cos(X.values[:, 1])), # Selection 1st feature
|
68 |
+
decimal=3,
|
69 |
)
|
70 |
|
71 |
def test_pipeline(self):
|
|
|
99 |
np.testing.assert_almost_equal(
|
100 |
tformat(torch.tensor(X)).detach().numpy(),
|
101 |
np.square(np.cos(X[:, 1])), # 2nd feature
|
102 |
+
decimal=3,
|
103 |
)
|
104 |
|
105 |
def test_mod_mapping(self):
|
|
|
116 |
torch_out = module(X)
|
117 |
|
118 |
np.testing.assert_array_almost_equal(
|
119 |
+
true_out.detach(), torch_out.detach(), decimal=3
|
120 |
)
|
121 |
|
122 |
def test_custom_operator(self):
|
|
|
156 |
np.testing.assert_almost_equal(
|
157 |
tformat(torch.tensor(X)).detach().numpy(),
|
158 |
np.sin(X[:, 1]),
|
159 |
+
decimal=3,
|
160 |
)
|
161 |
|
162 |
def test_feature_selection_custom_operators(self):
|
|
|
186 |
|
187 |
torch_output = torch_module(torch.tensor(X.values)).detach().numpy()
|
188 |
|
189 |
+
np.testing.assert_almost_equal(y.values, np_output, decimal=3)
|
190 |
+
np.testing.assert_almost_equal(y.values, torch_output, decimal=3)
|