MilesCranmer commited on
Commit
616ec5c
1 Parent(s): 0501132

test: list-like variable complexity

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +27 -1
pysr/test/test.py CHANGED
@@ -172,6 +172,26 @@ class TestPipeline(unittest.TestCase):
172
  self.assertLessEqual(mse1, 1e-4)
173
  self.assertLessEqual(mse2, 1e-4)
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  def test_multioutput_weighted_with_callable_temp_equation(self):
176
  X = self.X.copy()
177
  y = X[:, [0, 1]] ** 2
@@ -1053,8 +1073,14 @@ class TestDimensionalConstraints(unittest.TestCase):
1053
  """This just checks the number of units passed"""
1054
  use_custom_variable_names = False
1055
  variable_names = None
 
1056
  weights = None
1057
- args = (use_custom_variable_names, variable_names, weights)
 
 
 
 
 
1058
  valid_units = [
1059
  (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
1060
  (np.ones((10, 1)), np.ones(10), ["m/s"], None),
 
172
  self.assertLessEqual(mse1, 1e-4)
173
  self.assertLessEqual(mse2, 1e-4)
174
 
175
+ def test_custom_variable_complexity(self):
176
+ y = self.X[:, [0, 1]] ** 2
177
+ model = PySRRegressor(
178
+ binary_operators=["*", "+"],
179
+ verbosity=0,
180
+ **self.default_test_kwargs,
181
+ early_stop_condition="stop_if(l, c) = l < 1e-4 && c <= 7",
182
+ )
183
+ model.fit(
184
+ self.X,
185
+ y,
186
+ complexity_of_variables=[2, 3] + [100 for _ in range(self.X.shape[1] - 2)],
187
+ )
188
+ equations = model.equations_
189
+ self.assertLessEqual(equations[0].iloc[-1]["loss"], 1e-4)
190
+ self.assertLessEqual(equations[1].iloc[-1]["loss"], 1e-4)
191
+
192
+ self.assertEqual(model.get_best()[0]["complexity"], 5)
193
+ self.assertEqual(model.get_best()[1]["complexity"], 7)
194
+
195
  def test_multioutput_weighted_with_callable_temp_equation(self):
196
  X = self.X.copy()
197
  y = X[:, [0, 1]] ** 2
 
1073
  """This just checks the number of units passed"""
1074
  use_custom_variable_names = False
1075
  variable_names = None
1076
+ complexity_of_variables = 1
1077
  weights = None
1078
+ args = (
1079
+ use_custom_variable_names,
1080
+ variable_names,
1081
+ complexity_of_variables,
1082
+ weights,
1083
+ )
1084
  valid_units = [
1085
  (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
1086
  (np.ones((10, 1)), np.ones(10), ["m/s"], None),