MilesCranmer commited on
Commit
0e15dd6
1 Parent(s): a117981

Test multiple output units

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +6 -5
pysr/test/test.py CHANGED
@@ -920,7 +920,7 @@ class TestDimensionalConstraints(unittest.TestCase):
920
  self.X = self.rstate.randn(100, 5)
921
 
922
  def test_dimensional_constraints(self):
923
- y = np.cos(self.X[:, 0])
924
  model = PySRRegressor(
925
  binary_operators=[
926
  "my_add(x, y) = x + y",
@@ -938,12 +938,13 @@ class TestDimensionalConstraints(unittest.TestCase):
938
  "my_mul": lambda x, y: x * y,
939
  },
940
  )
941
- model.fit(self.X, y, X_units=["m", "m", "m", "m", "m"], y_units="m")
942
 
943
  # The best expression should have complexity larger than just 2:
944
- self.assertGreater(model.get_best()["complexity"], 2)
945
- self.assertLess(model.get_best()["loss"], 1e-6)
946
- self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
 
947
 
948
  def test_unit_checks(self):
949
  """This just checks the number of units passed"""
 
920
  self.X = self.rstate.randn(100, 5)
921
 
922
  def test_dimensional_constraints(self):
923
+ y = np.cos(self.X[:, [0, 1]])
924
  model = PySRRegressor(
925
  binary_operators=[
926
  "my_add(x, y) = x + y",
 
938
  "my_mul": lambda x, y: x * y,
939
  },
940
  )
941
+ model.fit(self.X, y, X_units=["m", "m", "m", "m", "m"], y_units=["m", "m"])
942
 
943
  # The best expression should have complexity larger than just 2:
944
+ for i in range(2):
945
+ self.assertGreater(model.get_best()[i]["complexity"], 2)
946
+ self.assertLess(model.get_best()[i]["loss"], 1e-6)
947
+ self.assertGreater(model.equations_[i].query("complexity <= 2").loss.min(), 1e-6)
948
 
949
  def test_unit_checks(self):
950
  """This just checks the number of units passed"""