MilesCranmer commited on
Commit
c3a1736
1 Parent(s): 91e26f9

Fix torch test

Browse files
Files changed (1) hide show
  1. test/test_torch.py +4 -3
test/test_torch.py CHANGED
@@ -76,7 +76,7 @@ class TestTorch(unittest.TestCase):
76
 
77
  equations = pd.DataFrame(
78
  {
79
- "Equation": ["1.0", "mycustomoperator(x0)"],
80
  "MSE": [1.0, 0.1],
81
  "Complexity": [1, 2],
82
  }
@@ -98,10 +98,11 @@ class TestTorch(unittest.TestCase):
98
  model.n_features = 3
99
  model.using_pandas = False
100
  model.refresh()
 
101
  # Will automatically use the set global state from get_hof.
102
- tformat = model.pytorch()
103
- self.assertEqual(str(tformat), "_SingleSymPyModule(expression=sin(x0))")
104
 
 
 
105
  np.testing.assert_almost_equal(
106
  tformat(torch.tensor(X)).detach().numpy(),
107
  np.sin(X[:, 0]), # Selection 1st feature
 
76
 
77
  equations = pd.DataFrame(
78
  {
79
+ "Equation": ["1.0", "mycustomoperator(x1)"],
80
  "MSE": [1.0, 0.1],
81
  "Complexity": [1, 2],
82
  }
 
98
  model.n_features = 3
99
  model.using_pandas = False
100
  model.refresh()
101
+ self.assertEqual(str(model.sympy()), "sin(x1)")
102
  # Will automatically use the set global state from get_hof.
 
 
103
 
104
+ tformat = model.pytorch()
105
+ self.assertEqual(str(tformat), "_SingleSymPyModule(expression=sin(x1))")
106
  np.testing.assert_almost_equal(
107
  tformat(torch.tensor(X)).detach().numpy(),
108
  np.sin(X[:, 0]), # Selection 1st feature