Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
3da0df5
1
Parent(s):
03d5a42
Fix pickling for multi-output
Browse files- pysr/sr.py +18 -6
pysr/sr.py
CHANGED
@@ -884,12 +884,24 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
884 |
if "equations_" in pickled_state:
|
885 |
pickled_state["output_torch_format"] = False
|
886 |
pickled_state["output_jax_format"] = False
|
887 |
-
|
888 |
-
["
|
889 |
-
|
890 |
-
|
891 |
-
pickled_state["equations_"]
|
892 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
893 |
return pickled_state
|
894 |
|
895 |
@property
|
|
|
884 |
if "equations_" in pickled_state:
|
885 |
pickled_state["output_torch_format"] = False
|
886 |
pickled_state["output_jax_format"] = False
|
887 |
+
if self.nout_ == 1:
|
888 |
+
pickled_columns = ~pickled_state["equations_"].columns.isin(
|
889 |
+
["jax_format", "torch_format"]
|
890 |
+
)
|
891 |
+
pickled_state["equations_"] = (
|
892 |
+
pickled_state["equations_"].loc[:, pickled_columns].copy()
|
893 |
+
)
|
894 |
+
else:
|
895 |
+
pickled_columns = [
|
896 |
+
~dataframe.columns.isin(["jax_format", "torch_format"])
|
897 |
+
for dataframe in pickled_state["equations_"]
|
898 |
+
]
|
899 |
+
pickled_state["equations_"] = [
|
900 |
+
dataframe.loc[:, signle_pickled_columns]
|
901 |
+
for dataframe, signle_pickled_columns in zip(
|
902 |
+
pickled_state["equations_"], pickled_columns
|
903 |
+
)
|
904 |
+
]
|
905 |
return pickled_state
|
906 |
|
907 |
@property
|