MilesCranmer commited on
Commit
04454ac
·
unverified ·
1 Parent(s): 2d025c2

Add unittests for units checks

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +39 -3
pysr/test/test.py CHANGED
@@ -19,6 +19,7 @@ from ..sr import (
19
  _handle_feature_selection,
20
  _csv_filename_to_pkl_filename,
21
  idx_model_selection,
 
22
  )
23
  from ..export_latex import to_latex
24
 
@@ -932,12 +933,47 @@ class TestDimensionalConstraints(unittest.TestCase):
932
  self.assertLess(model.get_best()["loss"], 1e-6)
933
  self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935
 
936
  # TODO: add tests for:
937
  # - custom operators + dimensions
938
- # - invalid number of dimensions
939
- # - X
940
- # - y
941
  # - no constants, so that it needs to find the right fraction
942
  # - custom dimensional_constraint_penalty
943
 
 
19
  _handle_feature_selection,
20
  _csv_filename_to_pkl_filename,
21
  idx_model_selection,
22
+ _check_assertions,
23
  )
24
  from ..export_latex import to_latex
25
 
 
933
  self.assertLess(model.get_best()["loss"], 1e-6)
934
  self.assertGreater(model.equations_.query("complexity <= 2").loss.min(), 1e-6)
935
 
936
+ def test_unit_checks(self):
937
+ """This just checks the number of units passed"""
938
+ use_custom_variable_names = False
939
+ variable_names = None
940
+ weights = None
941
+ args = (use_custom_variable_names, variable_names, weights)
942
+ valid_units = [
943
+ (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
944
+ (np.ones((10, 1)), np.ones(10), ["m/s"], None),
945
+ (np.ones((10, 1)), np.ones(10), None, "m/s"),
946
+ (np.ones((10, 1)), np.ones(10), None, ["m/s"]),
947
+ (np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
948
+ (np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", "km"]),
949
+ ]
950
+ for X, y, X_units, y_units in valid_units:
951
+ _check_assertions(
952
+ X,
953
+ *args,
954
+ y,
955
+ X_units,
956
+ y_units,
957
+ )
958
+ invalid_units = [
959
+ (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
960
+ (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "m"),
961
+ (np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
962
+ (np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
963
+ ]
964
+ for X, y, X_units, y_units in invalid_units:
965
+ with self.assertRaises(ValueError):
966
+ _check_assertions(
967
+ X,
968
+ *args,
969
+ y,
970
+ X_units,
971
+ y_units,
972
+ )
973
+
974
 
975
  # TODO: add tests for:
976
  # - custom operators + dimensions
 
 
 
977
  # - no constants, so that it needs to find the right fraction
978
  # - custom dimensional_constraint_penalty
979