Spaces:
Running
Running
MilesCranmer
commited on
Merge pull request #545 from MilesCranmer/fix-units
Browse files- pyproject.toml +1 -1
- pysr/sr.py +3 -1
- pysr/test/test.py +2 -2
pyproject.toml
CHANGED
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4 |
|
5 |
[project]
|
6 |
name = "pysr"
|
7 |
-
version = "0.17.
|
8 |
authors = [
|
9 |
{name = "Miles Cranmer", email = "[email protected]"},
|
10 |
]
|
|
|
4 |
|
5 |
[project]
|
6 |
name = "pysr"
|
7 |
+
version = "0.17.1"
|
8 |
authors = [
|
9 |
{name = "Miles Cranmer", email = "[email protected]"},
|
10 |
]
|
pysr/sr.py
CHANGED
@@ -1736,7 +1736,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1736 |
),
|
1737 |
y_variable_names=jl_y_variable_names,
|
1738 |
X_units=jl_array(self.X_units_),
|
1739 |
-
y_units=jl_array(self.y_units_)
|
|
|
|
|
1740 |
options=options,
|
1741 |
numprocs=cprocs,
|
1742 |
parallelism=parallelism,
|
|
|
1736 |
),
|
1737 |
y_variable_names=jl_y_variable_names,
|
1738 |
X_units=jl_array(self.X_units_),
|
1739 |
+
y_units=jl_array(self.y_units_)
|
1740 |
+
if isinstance(self.y_units_, list)
|
1741 |
+
else self.y_units_,
|
1742 |
options=options,
|
1743 |
numprocs=cprocs,
|
1744 |
parallelism=parallelism,
|
pysr/test/test.py
CHANGED
@@ -1038,7 +1038,7 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1038 |
valid_units = [
|
1039 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
1040 |
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|
1041 |
-
(np.ones((10, 1)), np.ones(10), None, "
|
1042 |
(np.ones((10, 1)), np.ones(10), None, ["m/s"]),
|
1043 |
(np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
|
1044 |
(np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", ""]),
|
@@ -1053,7 +1053,7 @@ class TestDimensionalConstraints(unittest.TestCase):
|
|
1053 |
)
|
1054 |
invalid_units = [
|
1055 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
|
1056 |
-
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "
|
1057 |
(np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
|
1058 |
(np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
|
1059 |
]
|
|
|
1038 |
valid_units = [
|
1039 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s"], "m"),
|
1040 |
(np.ones((10, 1)), np.ones(10), ["m/s"], None),
|
1041 |
+
(np.ones((10, 1)), np.ones(10), None, "km/s"),
|
1042 |
(np.ones((10, 1)), np.ones(10), None, ["m/s"]),
|
1043 |
(np.ones((10, 1)), np.ones((10, 1)), None, ["m/s"]),
|
1044 |
(np.ones((10, 1)), np.ones((10, 2)), None, ["m/s", ""]),
|
|
|
1053 |
)
|
1054 |
invalid_units = [
|
1055 |
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], None),
|
1056 |
+
(np.ones((10, 2)), np.ones(10), ["m/s", "s", "s^2"], "km"),
|
1057 |
(np.ones((10, 2)), np.ones((10, 2)), ["m/s", "s"], ["m"]),
|
1058 |
(np.ones((10, 1)), np.ones((10, 1)), "m/s", ["m"]),
|
1059 |
]
|