MilesCranmer commited on
Commit
7946ec0
·
unverified ·
2 Parent(s): b3a5026 7091a55

Merge pull request #545 from MilesCranmer/fix-units

Browse files
Files changed (3) hide show
  1. pyproject.toml +1 -1
  2. pysr/sr.py +3 -1
  3. pysr/test/test.py +2 -2
pyproject.toml CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
 
5
  [project]
6
  name = "pysr"
7
- version = "0.17.0"
8
  authors = [
9
  {name = "Miles Cranmer", email = "[email protected]"},
10
  ]
 
4
 
5
  [project]
6
  name = "pysr"
7
+ version = "0.17.1"
8
  authors = [
9
  {name = "Miles Cranmer", email = "[email protected]"},
10
  ]
pysr/sr.py CHANGED
@@ -1736,7 +1736,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1736
  ),
1737
  y_variable_names=jl_y_variable_names,
1738
  X_units=jl_array(self.X_units_),
1739
- y_units=jl_array(self.y_units_),
 
 
1740
  options=options,
1741
  numprocs=cprocs,
1742
  parallelism=parallelism,
 
1736
  ),
1737
  y_variable_names=jl_y_variable_names,
1738
  X_units=jl_array(self.X_units_),
1739
+ y_units=jl_array(self.y_units_)
1740
+ if isinstance(self.y_units_, list)
1741
+ else self.y_units_,
1742
  options=options,
1743
  numprocs=cprocs,
1744
  parallelism=parallelism,
pysr/test/test.py CHANGED
@@ -1038,7 +1038,7 @@ class TestDimensionalConstraints(unittest.TestCase):
1038
  valid_units = [
1039
  (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
1040
  (np.ones((10, 1)), np.ones(10), ["m/s"], None),
1041
- (np.ones((10, 1)), np.ones(10), None, "m/s"),
1042
  (np.ones((10, 1)), np.ones(10), None, ["m/s"]),
1043
  (np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
1044
  (np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", ""]),
@@ -1053,7 +1053,7 @@ class TestDimensionalConstraints(unittest.TestCase):
1053
  )
1054
  invalid_units = [
1055
  (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
1056
- (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "m"),
1057
  (np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
1058
  (np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
1059
  ]
 
1038
  valid_units = [
1039
  (np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
1040
  (np.ones((10, 1)), np.ones(10), ["m/s"], None),
1041
+ (np.ones((10, 1)), np.ones(10), None, "km/s"),
1042
  (np.ones((10, 1)), np.ones(10), None, ["m/s"]),
1043
  (np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
1044
  (np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", ""]),
 
1053
  )
1054
  invalid_units = [
1055
  (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
1056
+ (np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "km"),
1057
  (np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
1058
  (np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
1059
  ]