MilesCranmer commited on
Commit
a5eaab9
1 Parent(s): 04aa23c

refactor: help with type inference of `get_best`

Browse files
Files changed (1) hide show
  1. pysr/sr.py +3 -3
pysr/sr.py CHANGED
@@ -2169,10 +2169,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2169
  self.set_params(output_torch_format=True)
2170
  self.refresh()
2171
  best_equation = self.get_best(index=index)
2172
- if isinstance(best_equation, pd.Series):
2173
- return best_equation["torch_format"]
2174
- else:
2175
  return [eq["torch_format"] for eq in best_equation]
 
 
2176
 
2177
  def _read_equation_file(self):
2178
  """Read the hall of fame file created by `SymbolicRegression.jl`."""
 
2169
  self.set_params(output_torch_format=True)
2170
  self.refresh()
2171
  best_equation = self.get_best(index=index)
2172
+ if isinstance(best_equation, list):
 
 
2173
  return [eq["torch_format"] for eq in best_equation]
2174
+ else:
2175
+ return best_equation["torch_format"]
2176
 
2177
  def _read_equation_file(self):
2178
  """Read the hall of fame file created by `SymbolicRegression.jl`."""