MilesCranmer commited on
Commit
35ca811
1 Parent(s): 518eb85

Make manual selection work for multi-output

Browse files
Files changed (1) hide show
  1. pysr/sr.py +5 -1
pysr/sr.py CHANGED
@@ -793,6 +793,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
793
  raise ValueError("No equations have been generated yet.")
794
 
795
  if index is not None:
 
 
 
796
  return self.equations.iloc[index]
797
 
798
  if self.model_selection == "accuracy":
@@ -846,7 +849,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
846
 
847
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
848
  :type X: np.ndarray/pandas.DataFrame
849
- :param index: Optional. If you want to predict an expression using a particular row of
 
850
  `self.equations`, you may specify the index here.
851
  :type index: int
852
  :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
 
793
  raise ValueError("No equations have been generated yet.")
794
 
795
  if index is not None:
796
+ if isinstance(self.equations, list):
797
+ assert isinstance(index, list)
798
+ return [self.equations.iloc[i] for i in index]
799
  return self.equations.iloc[index]
800
 
801
  if self.model_selection == "accuracy":
 
849
 
850
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
851
  :type X: np.ndarray/pandas.DataFrame
852
+ :param index: Optional. If you want to compute the output of
853
+ an expression using a particular row of
854
  `self.equations`, you may specify the index here.
855
  :type index: int
856
  :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).