MilesCranmer commited on
Commit
ad8332d
1 Parent(s): 887e02d

Allow y to be a pandas dataframe or series

Browse files
Files changed (2) hide show
  1. pysr/sr.py +2 -1
  2. test/test.py +2 -0
pysr/sr.py CHANGED
@@ -957,7 +957,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
957
  if len(X.shape) == 1:
958
  X = X[:, None]
959
 
960
- assert not isinstance(y, pd.DataFrame)
 
961
 
962
  if variable_names is None or len(variable_names) == 0:
963
  variable_names = [f"x{i}" for i in range(X.shape[1])]
 
957
  if len(X.shape) == 1:
958
  X = X[:, None]
959
 
960
+ if isinstance(y, pd.DataFrame) or isinstance(y, pd.Series):
961
+ y = np.array(y)
962
 
963
  if variable_names is None or len(variable_names) == 0:
964
  variable_names = [f"x{i}" for i in range(X.shape[1])]
test/test.py CHANGED
@@ -143,6 +143,8 @@ class TestPipeline(unittest.TestCase):
143
  y = true_fn(X)
144
  noise = np.random.randn(500) * 0.01
145
  y = y + noise
 
 
146
  # Resampled array is a different order of features:
147
  Xresampled = pd.DataFrame(
148
  {
 
143
  y = true_fn(X)
144
  noise = np.random.randn(500) * 0.01
145
  y = y + noise
146
+ # We also test y as a pandas array:
147
+ y = pd.Series(y)
148
  # Resampled array is a different order of features:
149
  Xresampled = pd.DataFrame(
150
  {