MilesCranmer commited on
Commit
90d3ef7
1 Parent(s): 0857108

Add unittest for complex numbers

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +14 -0
pysr/test/test.py CHANGED
@@ -181,6 +181,20 @@ class TestPipeline(unittest.TestCase):
181
  print("Model equations: ", model.sympy()[1])
182
  print("True equation: x1^2")
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  def test_empty_operators_single_input_warm_start(self):
185
  X = self.rstate.randn(100, 1)
186
  y = X[:, 0] + 3.0
 
181
  print("Model equations: ", model.sympy()[1])
182
  print("True equation: x1^2")
183
 
184
+ def test_complex_equations_anonymous_stop(self):
185
+ X = self.rstate.randn(100, 3) + 1j * self.rstate.randn(100, 3)
186
+ y = (2 + 1j) * np.cos(X[:, 0] * (0.5 - 0.3j))
187
+ model = PySRRegressor(
188
+ binary_operators=["+", "-", "*"],
189
+ unary_operators=["cos"],
190
+ **self.default_test_kwargs,
191
+ early_stop_condition="(loss, complexity) -> loss <= 1e-4 && complexity <= 6",
192
+ )
193
+ model.fit(X, y)
194
+ test_y = model.predict(X)
195
+ self.assertTrue(np.issubdtype(test_y.dtype, np.complexfloating))
196
+ self.assertLessEqual(np.average(np.abs(test_y - y) ** 2), 1e-4)
197
+
198
  def test_empty_operators_single_input_warm_start(self):
199
  X = self.rstate.randn(100, 1)
200
  y = X[:, 0] + 3.0