MilesCranmer commited on
Commit
518eb85
1 Parent(s): 4839f5f

Change row to index in specifying which expression

Browse files
Files changed (1) hide show
  1. pysr/sr.py +29 -29
pysr/sr.py CHANGED
@@ -779,21 +779,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
779
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
780
  }
781
 
782
- def get_best(self, row=None):
783
  """Get best equation using `model_selection`.
784
 
785
- :param row: Optional. If you wish to select a particular equation
786
  from `self.equations`, give the row number here. This overrides
787
  the `model_selection` parameter.
788
- :type row: int
789
  :returns: Dictionary representing the best expression found.
790
  :type: pd.Series
791
  """
792
  if self.equations is None:
793
  raise ValueError("No equations have been generated yet.")
794
 
795
- if row is not None:
796
- return self.equations.iloc[row]
797
 
798
  if self.model_selection == "accuracy":
799
  if isinstance(self.equations, list):
@@ -838,7 +838,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
838
  # such as extra_sympy_mappings.
839
  self.equations = self.get_hof()
840
 
841
- def predict(self, X, row=None):
842
  """Predict y from input X using the equation chosen by `model_selection`.
843
 
844
  You may see what equation is used by printing this object. X should have the same
@@ -846,60 +846,60 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
846
 
847
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
848
  :type X: np.ndarray/pandas.DataFrame
849
- :param row: Optional. If you want to predict an expression using a particular row of
850
- `self.equations`, you may specify the row here.
851
- :type row: int
852
  :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
853
  :type: np.ndarray
854
  """
855
  self.refresh()
856
- best = self.get_best(row=row)
857
  if self.multioutput:
858
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
859
  return best["lambda_format"](X)
860
 
861
- def sympy(self, row=None):
862
  """Return sympy representation of the equation(s) chosen by `model_selection`.
863
 
864
- :param row: Optional. If you wish to select a particular equation
865
- from `self.equations`, give the row number here. This overrides
866
  the `model_selection` parameter.
867
- :type row: int
868
  :returns: SymPy representation of the best expression.
869
  """
870
  self.refresh()
871
- best = self.get_best(row=row)
872
  if self.multioutput:
873
  return [eq["sympy_format"] for eq in best]
874
  return best["sympy_format"]
875
 
876
- def latex(self, row=None):
877
  """Return latex representation of the equation(s) chosen by `model_selection`.
878
 
879
- :param row: Optional. If you wish to select a particular equation
880
- from `self.equations`, give the row number here. This overrides
881
  the `model_selection` parameter.
882
- :type row: int
883
  :returns: LaTeX expression as a string
884
  :type: str
885
  """
886
  self.refresh()
887
- sympy_representation = self.sympy(row=row)
888
  if self.multioutput:
889
  return [sympy.latex(s) for s in sympy_representation]
890
  return sympy.latex(sympy_representation)
891
 
892
- def jax(self, row=None):
893
  """Return jax representation of the equation(s) chosen by `model_selection`.
894
 
895
  Each equation (multiple given if there are multiple outputs) is a dictionary
896
  containing {"callable": func, "parameters": params}. To call `func`, pass
897
  func(X, params). This function is differentiable using `jax.grad`.
898
 
899
- :param row: Optional. If you wish to select a particular equation
900
- from `self.equations`, give the row number here. This overrides
901
  the `model_selection` parameter.
902
- :type row: int
903
  :returns: Dictionary of callable jax function in "callable" key,
904
  and jax array of parameters as "parameters" key.
905
  :type: dict
@@ -912,12 +912,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
912
  )
913
  self.set_params(output_jax_format=True)
914
  self.refresh()
915
- best = self.get_best(row=row)
916
  if self.multioutput:
917
  return [eq["jax_format"] for eq in best]
918
  return best["jax_format"]
919
 
920
- def pytorch(self, row=None):
921
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
922
 
923
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
@@ -926,10 +926,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
926
  column ordering as trained with.
927
 
928
 
929
- :param row: Optional. If you wish to select a particular equation
930
  from `self.equations`, give the row number here. This overrides
931
  the `model_selection` parameter.
932
- :type row: int
933
  :returns: PyTorch module representing the expression.
934
  :type: torch.nn.Module
935
  """
@@ -941,7 +941,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
941
  )
942
  self.set_params(output_torch_format=True)
943
  self.refresh()
944
- best = self.get_best(row=row)
945
  if self.multioutput:
946
  return [eq["torch_format"] for eq in best]
947
  return best["torch_format"]
 
779
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
780
  }
781
 
782
+ def get_best(self, index=None):
783
  """Get best equation using `model_selection`.
784
 
785
+ :param index: Optional. If you wish to select a particular equation
786
  from `self.equations`, give the row number here. This overrides
787
  the `model_selection` parameter.
788
+ :type index: int
789
  :returns: Dictionary representing the best expression found.
790
  :type: pd.Series
791
  """
792
  if self.equations is None:
793
  raise ValueError("No equations have been generated yet.")
794
 
795
+ if index is not None:
796
+ return self.equations.iloc[index]
797
 
798
  if self.model_selection == "accuracy":
799
  if isinstance(self.equations, list):
 
838
  # such as extra_sympy_mappings.
839
  self.equations = self.get_hof()
840
 
841
+ def predict(self, X, index=None):
842
  """Predict y from input X using the equation chosen by `model_selection`.
843
 
844
  You may see what equation is used by printing this object. X should have the same
 
846
 
847
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
848
  :type X: np.ndarray/pandas.DataFrame
849
+ :param index: Optional. If you want to predict an expression using a particular row of
850
+ `self.equations`, you may specify the index here.
851
+ :type index: int
852
  :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
853
  :type: np.ndarray
854
  """
855
  self.refresh()
856
+ best = self.get_best(index=index)
857
  if self.multioutput:
858
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
859
  return best["lambda_format"](X)
860
 
861
+ def sympy(self, index=None):
862
  """Return sympy representation of the equation(s) chosen by `model_selection`.
863
 
864
+ :param index: Optional. If you wish to select a particular equation
865
+ from `self.equations`, give the index number here. This overrides
866
  the `model_selection` parameter.
867
+ :type index: int
868
  :returns: SymPy representation of the best expression.
869
  """
870
  self.refresh()
871
+ best = self.get_best(index=index)
872
  if self.multioutput:
873
  return [eq["sympy_format"] for eq in best]
874
  return best["sympy_format"]
875
 
876
+ def latex(self, index=None):
877
  """Return latex representation of the equation(s) chosen by `model_selection`.
878
 
879
+ :param index: Optional. If you wish to select a particular equation
880
+ from `self.equations`, give the index number here. This overrides
881
  the `model_selection` parameter.
882
+ :type index: int
883
  :returns: LaTeX expression as a string
884
  :type: str
885
  """
886
  self.refresh()
887
+ sympy_representation = self.sympy(index=index)
888
  if self.multioutput:
889
  return [sympy.latex(s) for s in sympy_representation]
890
  return sympy.latex(sympy_representation)
891
 
892
+ def jax(self, index=None):
893
  """Return jax representation of the equation(s) chosen by `model_selection`.
894
 
895
  Each equation (multiple given if there are multiple outputs) is a dictionary
896
  containing {"callable": func, "parameters": params}. To call `func`, pass
897
  func(X, params). This function is differentiable using `jax.grad`.
898
 
899
+ :param index: Optional. If you wish to select a particular equation
900
+ from `self.equations`, give the index number here. This overrides
901
  the `model_selection` parameter.
902
+ :type index: int
903
  :returns: Dictionary of callable jax function in "callable" key,
904
  and jax array of parameters as "parameters" key.
905
  :type: dict
 
912
  )
913
  self.set_params(output_jax_format=True)
914
  self.refresh()
915
+ best = self.get_best(index=index)
916
  if self.multioutput:
917
  return [eq["jax_format"] for eq in best]
918
  return best["jax_format"]
919
 
920
+ def pytorch(self, index=None):
921
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
922
 
923
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
 
926
  column ordering as trained with.
927
 
928
 
929
+ :param index: Optional. If you wish to select a particular equation
930
  from `self.equations`, give the row number here. This overrides
931
  the `model_selection` parameter.
932
+ :type index: int
933
  :returns: PyTorch module representing the expression.
934
  :type: torch.nn.Module
935
  """
 
941
  )
942
  self.set_params(output_torch_format=True)
943
  self.refresh()
944
+ best = self.get_best(index=index)
945
  if self.multioutput:
946
  return [eq["torch_format"] for eq in best]
947
  return best["torch_format"]