MilesCranmer commited on
Commit
85d18bf
1 Parent(s): 5af7e2e

Allow for pandas array and automatic variable names

Browse files
Files changed (1) hide show
  1. pysr/sr.py +7 -1
pysr/sr.py CHANGED
@@ -84,7 +84,9 @@ def pysr(X=None, y=None, weights=None,
84
  equations, but you should adjust `threads`, `niterations`,
85
  `binary_operators`, `unary_operators` to your requirements.
86
 
87
- :param X: np.ndarray, 2D array. Rows are examples, columns are features.
 
 
88
  :param y: np.ndarray, 1D array. Rows are examples.
89
  :param weights: np.ndarray, 1D array. Each row is how to weight the
90
  mean-square-error loss on weights.
@@ -148,6 +150,10 @@ def pysr(X=None, y=None, weights=None,
148
  if maxdepth is None:
149
  maxdepth = maxsize
150
 
 
 
 
 
151
  # Check for potential errors before they happen
152
  assert len(unary_operators) + len(binary_operators) > 0
153
  assert len(X.shape) == 2
 
84
  equations, but you should adjust `threads`, `niterations`,
85
  `binary_operators`, `unary_operators` to your requirements.
86
 
87
+ :param X: np.ndarray or pandas.DataFrame, 2D array. Rows are examples,
88
+ columns are features. If pandas DataFrame, the columns are used
89
+ for variable names (so make sure they don't contain spaces).
90
  :param y: np.ndarray, 1D array. Rows are examples.
91
  :param weights: np.ndarray, 1D array. Each row is how to weight the
92
  mean-square-error loss on weights.
 
150
  if maxdepth is None:
151
  maxdepth = maxsize
152
 
153
+ if isinstance(X, pd.DataFrame):
154
+ variable_names = list(X.columns)
155
+ X = np.array(X)
156
+
157
  # Check for potential errors before they happen
158
  assert len(unary_operators) + len(binary_operators) > 0
159
  assert len(X.shape) == 2