MilesCranmer commited on
Commit
10bac39
1 Parent(s): 7a792a8

Fix edgecase when expression is a constant

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -3
pysr/sr.py CHANGED
@@ -271,12 +271,13 @@ class CallableEquation:
271
  return f"PySRFunction(X=>{self._sympy})"
272
 
273
  def __call__(self, X):
 
274
  if isinstance(X, pd.DataFrame):
275
  # Lambda function takes as argument:
276
- return self._lambda(**{k: X[k].values for k in X.columns})
277
  elif self._selection is not None:
278
- return self._lambda(*X[:, self._selection].T)
279
- return self._lambda(*X.T)
280
 
281
 
282
  def _get_julia_project(julia_project):
 
271
  return f"PySRFunction(X=>{self._sympy})"
272
 
273
  def __call__(self, X):
274
+ expected_shape = (X.shape[0],)
275
  if isinstance(X, pd.DataFrame):
276
  # Lambda function takes as argument:
277
+ return self._lambda(**{k: X[k].values for k in X.columns}) * np.ones(expected_shape)
278
  elif self._selection is not None:
279
+ return self._lambda(*X[:, self._selection].T) * np.ones(expected_shape)
280
+ return self._lambda(*X.T) * np.ones(expected_shape)
281
 
282
 
283
  def _get_julia_project(julia_project):