MilesCranmer commited on
Commit
06a77ea
1 Parent(s): 09617a6

test: simplify custom complexity test

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +11 -12
pysr/test/test.py CHANGED
@@ -179,26 +179,25 @@ class TestPipeline(unittest.TestCase):
179
 
180
  def test_custom_variable_complexity(self):
181
  for case in (1, 2):
182
- y = self.X[:, [0, 1]] ** 2
183
  model = PySRRegressor(
184
- binary_operators=["*", "+"],
185
  verbosity=0,
186
  **self.default_test_kwargs,
187
- early_stop_condition=f"stop_if(l, c) = l < 1e-8 && c <= {7 if case == 1 else 5}",
188
  )
189
  if case == 1:
190
- complexity_of_variables = [2, 3] + [
191
- 100 for _ in range(self.X.shape[1] - 2)
192
- ]
193
  elif case == 2:
194
  complexity_of_variables = 2
195
- model.fit(self.X, y, complexity_of_variables=complexity_of_variables)
196
- equations = model.equations_
197
- self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
198
- self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
 
199
 
200
- self.assertEqual(model.get_best()[0]["complexity"], 5)
201
- self.assertEqual(model.get_best()[1]["complexity"], 7 if case == 1 else 5)
202
 
203
  def test_multioutput_weighted_with_callable_temp_equation(self):
204
  X = self.X.copy()
 
179
 
180
  def test_custom_variable_complexity(self):
181
  for case in (1, 2):
182
+ y = self.X[:, [0, 1]]
183
  model = PySRRegressor(
184
+ binary_operators=["+"],
185
  verbosity=0,
186
  **self.default_test_kwargs,
187
+ early_stop_condition=f"stop_if_{case}(l, c) = l < 1e-8 && c <= {3 if case == 1 else 2}",
188
  )
189
  if case == 1:
190
+ complexity_of_variables = [2, 3]
 
 
191
  elif case == 2:
192
  complexity_of_variables = 2
193
+ model.fit(
194
+ self.X[:, [0, 1]], y, complexity_of_variables=complexity_of_variables
195
+ )
196
+ self.assertLessEqual(model.get_best()[0]["loss"], 1e-8)
197
+ self.assertLessEqual(model.get_best()[1]["loss"], 1e-8)
198
 
199
+ self.assertEqual(model.get_best()[0]["complexity"], 2)
200
+ self.assertEqual(model.get_best()[1]["complexity"], 3 if case == 1 else 2)
201
 
202
  def test_multioutput_weighted_with_callable_temp_equation(self):
203
  X = self.X.copy()