MilesCranmer commited on
Commit
dca02e2
1 Parent(s): aef1f27

Fix pytorch test

Browse files
Files changed (1) hide show
  1. test/test_torch.py +17 -11
test/test_torch.py CHANGED
@@ -1,7 +1,7 @@
1
  import unittest
2
  import numpy as np
3
  import pandas as pd
4
- from pysr import sympy2torch, get_hof, PySRRegressor
5
  import torch
6
  import sympy
7
 
@@ -24,7 +24,7 @@ class TestTorch(unittest.TestCase):
24
  X = np.random.randn(100, 10)
25
  equations = pd.DataFrame(
26
  {
27
- "Equation": ["1.0", "cos(x0)", "square(cos(x0))"],
28
  "MSE": [1.0, 0.1, 1e-5],
29
  "Complexity": [1, 2, 3],
30
  }
@@ -34,9 +34,9 @@ class TestTorch(unittest.TestCase):
34
  "equation_file.csv.bkup", sep="|"
35
  )
36
 
37
- equations = get_hof(
38
- "equation_file.csv",
39
- n_features=2, # TODO: Why is this 2 and not 3?
40
  variables_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
@@ -44,8 +44,12 @@ class TestTorch(unittest.TestCase):
44
  nout=1,
45
  selection=[1, 2, 3],
46
  )
 
 
 
47
 
48
- tformat = equations.iloc[-1].torch_format
 
49
  np.testing.assert_almost_equal(
50
  tformat(torch.tensor(X)).detach().numpy(),
51
  np.square(np.cos(X[:, 1])), # Selection 1st feature
@@ -84,9 +88,9 @@ class TestTorch(unittest.TestCase):
84
  "equation_file_custom_operator.csv.bkup", sep="|"
85
  )
86
 
87
- get_hof(
88
- "equation_file_custom_operator.csv",
89
- n_features=3,
90
  variables_names="x1 x2 x3".split(" "),
91
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
92
  extra_torch_mappings={"mycustomoperator": torch.sin},
@@ -95,10 +99,12 @@ class TestTorch(unittest.TestCase):
95
  nout=1,
96
  selection=[0, 1, 2],
97
  )
98
-
99
- model = PySRRegressor()
 
100
  # Will automatically use the set global state from get_hof.
101
  tformat = model.pytorch()
 
102
 
103
  np.testing.assert_almost_equal(
104
  tformat(torch.tensor(X)).detach().numpy(),
 
1
  import unittest
2
  import numpy as np
3
  import pandas as pd
4
+ from pysr import sympy2torch, PySRRegressor
5
  import torch
6
  import sympy
7
 
 
24
  X = np.random.randn(100, 10)
25
  equations = pd.DataFrame(
26
  {
27
+ "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
28
  "MSE": [1.0, 0.1, 1e-5],
29
  "Complexity": [1, 2, 3],
30
  }
 
34
  "equation_file.csv.bkup", sep="|"
35
  )
36
 
37
+ model = PySRRegressor(
38
+ model_selection="accuracy",
39
+ equation_file="equation_file.csv",
40
  variables_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
 
44
  nout=1,
45
  selection=[1, 2, 3],
46
  )
47
+ model.n_features = 2 # TODO: Why is this 2 and not 3?
48
+ model.using_pandas = False
49
+ model.refresh()
50
 
51
+ tformat = model.pytorch()
52
+ self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
53
  np.testing.assert_almost_equal(
54
  tformat(torch.tensor(X)).detach().numpy(),
55
  np.square(np.cos(X[:, 1])), # Selection 1st feature
 
88
  "equation_file_custom_operator.csv.bkup", sep="|"
89
  )
90
 
91
+ model = PySRRegressor(
92
+ model_selection="accuracy",
93
+ equation_file="equation_file_custom_operator.csv",
94
  variables_names="x1 x2 x3".split(" "),
95
  extra_sympy_mappings={"mycustomoperator": sympy.sin},
96
  extra_torch_mappings={"mycustomoperator": torch.sin},
 
99
  nout=1,
100
  selection=[0, 1, 2],
101
  )
102
+ model.n_features = 3
103
+ model.using_pandas = False
104
+ model.refresh()
105
  # Will automatically use the set global state from get_hof.
106
  tformat = model.pytorch()
107
+ self.assertEqual(str(tformat), "_SingleSymPyModule(expression=sin(x0))")
108
 
109
  np.testing.assert_almost_equal(
110
  tformat(torch.tensor(X)).detach().numpy(),