Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
c3a1736
1
Parent(s):
91e26f9
Fix torch test
Browse files- test/test_torch.py +4 -3
test/test_torch.py
CHANGED
@@ -76,7 +76,7 @@ class TestTorch(unittest.TestCase):
|
|
76 |
|
77 |
equations = pd.DataFrame(
|
78 |
{
|
79 |
-
"Equation": ["1.0", "mycustomoperator(
|
80 |
"MSE": [1.0, 0.1],
|
81 |
"Complexity": [1, 2],
|
82 |
}
|
@@ -98,10 +98,11 @@ class TestTorch(unittest.TestCase):
|
|
98 |
model.n_features = 3
|
99 |
model.using_pandas = False
|
100 |
model.refresh()
|
|
|
101 |
# Will automatically use the set global state from get_hof.
|
102 |
-
tformat = model.pytorch()
|
103 |
-
self.assertEqual(str(tformat), "_SingleSymPyModule(expression=sin(x0))")
|
104 |
|
|
|
|
|
105 |
np.testing.assert_almost_equal(
|
106 |
tformat(torch.tensor(X)).detach().numpy(),
|
107 |
np.sin(X[:, 0]), # Selection 1st feature
|
|
|
76 |
|
77 |
equations = pd.DataFrame(
|
78 |
{
|
79 |
+
"Equation": ["1.0", "mycustomoperator(x1)"],
|
80 |
"MSE": [1.0, 0.1],
|
81 |
"Complexity": [1, 2],
|
82 |
}
|
|
|
98 |
model.n_features = 3
|
99 |
model.using_pandas = False
|
100 |
model.refresh()
|
101 |
+
self.assertEqual(str(model.sympy()), "sin(x1)")
|
102 |
# Will automatically use the set global state from get_hof.
|
|
|
|
|
103 |
|
104 |
+
tformat = model.pytorch()
|
105 |
+
self.assertEqual(str(tformat), "_SingleSymPyModule(expression=sin(x1))")
|
106 |
np.testing.assert_almost_equal(
|
107 |
tformat(torch.tensor(X)).detach().numpy(),
|
108 |
np.sin(X[:, 0]), # Selection 1st feature
|