MilesCranmer commited on
Commit
3da0df5
1 Parent(s): 03d5a42

Fix pickling for multi-output

Browse files
Files changed (1) hide show
  1. pysr/sr.py +18 -6
pysr/sr.py CHANGED
@@ -884,12 +884,24 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
884
  if "equations_" in pickled_state:
885
  pickled_state["output_torch_format"] = False
886
  pickled_state["output_jax_format"] = False
887
- pickled_columns = ~pickled_state["equations_"].columns.isin(
888
- ["jax_format", "torch_format"]
889
- )
890
- pickled_state["equations_"] = (
891
- pickled_state["equations_"].loc[:, pickled_columns].copy()
892
- )
 
 
 
 
 
 
 
 
 
 
 
 
893
  return pickled_state
894
 
895
  @property
 
884
  if "equations_" in pickled_state:
885
  pickled_state["output_torch_format"] = False
886
  pickled_state["output_jax_format"] = False
887
+ if self.nout_ == 1:
888
+ pickled_columns = ~pickled_state["equations_"].columns.isin(
889
+ ["jax_format", "torch_format"]
890
+ )
891
+ pickled_state["equations_"] = (
892
+ pickled_state["equations_"].loc[:, pickled_columns].copy()
893
+ )
894
+ else:
895
+ pickled_columns = [
896
+ ~dataframe.columns.isin(["jax_format", "torch_format"])
897
+ for dataframe in pickled_state["equations_"]
898
+ ]
899
+ pickled_state["equations_"] = [
900
+ dataframe.loc[:, signle_pickled_columns]
901
+ for dataframe, signle_pickled_columns in zip(
902
+ pickled_state["equations_"], pickled_columns
903
+ )
904
+ ]
905
  return pickled_state
906
 
907
  @property