MilesCranmer commited on
Commit
7113eed
1 Parent(s): 7909e90

style: use pandas indexing for return values

Browse files
Files changed (1) hide show
  1. pysr/sr.py +7 -5
pysr/sr.py CHANGED
@@ -47,6 +47,7 @@ from .julia_helpers import (
47
  )
48
  from .julia_import import SymbolicRegression, jl
49
  from .utils import (
 
50
  _csv_filename_to_pkl_filename,
51
  _preprocess_julia_floats,
52
  _safe_check_feature_names_in,
@@ -1037,7 +1038,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1037
  all_equations = equations
1038
 
1039
  for i, equations in enumerate(all_equations):
1040
- selected = ["" for _ in range(len(equations))]
1041
  chosen_row = idx_model_selection(equations, self.model_selection)
1042
  selected[chosen_row] = ">>>>"
1043
  repr_equations = pd.DataFrame(
@@ -1191,12 +1192,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1191
 
1192
  if isinstance(self.equations_, list):
1193
  return [
1194
- eq.iloc[idx_model_selection(eq, self.model_selection)]
1195
  for eq in self.equations_
1196
  ]
1197
- return self.equations_.iloc[
1198
- idx_model_selection(self.equations_, self.model_selection)
1199
- ]
 
1200
 
1201
  def _setup_equation_file(self):
1202
  """
 
47
  )
48
  from .julia_import import SymbolicRegression, jl
49
  from .utils import (
50
+ ArrayLike,
51
  _csv_filename_to_pkl_filename,
52
  _preprocess_julia_floats,
53
  _safe_check_feature_names_in,
 
1038
  all_equations = equations
1039
 
1040
  for i, equations in enumerate(all_equations):
1041
+ selected = pd.Series([""] * len(equations), index=equations.index)
1042
  chosen_row = idx_model_selection(equations, self.model_selection)
1043
  selected[chosen_row] = ">>>>"
1044
  repr_equations = pd.DataFrame(
 
1192
 
1193
  if isinstance(self.equations_, list):
1194
  return [
1195
+ eq.loc[idx_model_selection(eq, self.model_selection)]
1196
  for eq in self.equations_
1197
  ]
1198
+ else:
1199
+ return self.equations_.loc[
1200
+ idx_model_selection(self.equations_, self.model_selection)
1201
+ ]
1202
 
1203
  def _setup_equation_file(self):
1204
  """