MilesCranmer commited on
Commit
d4d95e5
1 Parent(s): 7d19ebb

Add test for custom torch operator

Browse files
Files changed (1) hide show
  1. test/test_torch.py +35 -1
test/test_torch.py CHANGED
@@ -36,7 +36,7 @@ class TestTorch(unittest.TestCase):
36
 
37
  equations = get_hof(
38
  "equation_file.csv",
39
- n_features=2,
40
  variables_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
@@ -68,3 +68,37 @@ class TestTorch(unittest.TestCase):
68
  np.testing.assert_array_almost_equal(
69
  true_out.detach(), torch_out.detach(), decimal=4
70
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  equations = get_hof(
38
  "equation_file.csv",
39
+ n_features=2, # TODO: Why is this 2 and not 3?
40
  variables_names="x1 x2 x3".split(" "),
41
  extra_sympy_mappings={},
42
  output_torch_format=True,
 
68
  np.testing.assert_array_almost_equal(
69
  true_out.detach(), torch_out.detach(), decimal=4
70
  )
71
+
72
+ def test_custom_operator(self):
73
+ X = np.random.randn(100, 3)
74
+
75
+ equations = pd.DataFrame(
76
+ {
77
+ "Equation": ["1.0", "mycustomoperator(x0)"],
78
+ "MSE": [1.0, 0.1],
79
+ "Complexity": [1, 2],
80
+ }
81
+ )
82
+
83
+ equations["Complexity MSE Equation".split(" ")].to_csv(
84
+ "equation_file_custom_operator.csv.bkup", sep="|"
85
+ )
86
+
87
+ equations = get_hof(
88
+ "equation_file_custom_operator.csv",
89
+ n_features=3,
90
+ variables_names="x1 x2 x3".split(" "),
91
+ extra_sympy_mappings={"mycustomoperator": sympy.sin},
92
+ extra_torch_mappings={"mycustomoperator": torch.sin},
93
+ output_torch_format=True,
94
+ multioutput=False,
95
+ nout=1,
96
+ selection=[0, 1, 2],
97
+ )
98
+
99
+ tformat = equations.iloc[-1].torch_format
100
+ np.testing.assert_almost_equal(
101
+ tformat(torch.tensor(X)).detach().numpy(),
102
+ np.sin(X[:, 0]), # Selection 1st feature
103
+ decimal=4,
104
+ )