MilesCranmer committed on
Commit
fbbe578
1 Parent(s): 887e02d

Add mechanism to manually do model selection

Browse files
Files changed (1) hide show
  1. pysr/sr.py +62 -15
pysr/sr.py CHANGED
@@ -779,10 +779,22 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
779
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
780
  }
781
 
782
- def get_best(self):
783
- """Get best equation using `model_selection`."""
 
 
 
 
 
 
 
 
784
  if self.equations is None:
785
  raise ValueError("No equations have been generated yet.")
 
 
 
 
786
  if self.model_selection == "accuracy":
787
  if isinstance(self.equations, list):
788
  return [eq.iloc[-1] for eq in self.equations]
@@ -826,7 +838,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
826
  # such as extra_sympy_mappings.
827
  self.equations = self.get_hof()
828
 
829
- def predict(self, X):
830
  """Predict y from input X using the equation chosen by `model_selection`.
831
 
832
  You may see what equation is used by printing this object. X should have the same
@@ -834,36 +846,63 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
834
 
835
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
836
  :type X: np.ndarray/pandas.DataFrame
837
- :return: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
 
 
 
 
838
  """
839
  self.refresh()
840
- best = self.get_best()
841
  if self.multioutput:
842
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
843
  return best["lambda_format"](X)
844
 
845
- def sympy(self):
846
- """Return sympy representation of the equation(s) chosen by `model_selection`."""
 
 
 
 
 
 
 
847
  self.refresh()
848
- best = self.get_best()
849
  if self.multioutput:
850
  return [eq["sympy_format"] for eq in best]
851
  return best["sympy_format"]
852
 
853
- def latex(self):
854
- """Return latex representation of the equation(s) chosen by `model_selection`."""
 
 
 
 
 
 
 
 
855
  self.refresh()
856
- sympy_representation = self.sympy()
857
  if self.multioutput:
858
  return [sympy.latex(s) for s in sympy_representation]
859
  return sympy.latex(sympy_representation)
860
 
861
- def jax(self):
862
  """Return jax representation of the equation(s) chosen by `model_selection`.
863
 
864
  Each equation (multiple given if there are multiple outputs) is a dictionary
865
  containing {"callable": func, "parameters": params}. To call `func`, pass
866
  func(X, params). This function is differentiable using `jax.grad`.
 
 
 
 
 
 
 
 
867
  """
868
  if self.using_pandas:
869
  warnings.warn(
@@ -873,18 +912,26 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
873
  )
874
  self.set_params(output_jax_format=True)
875
  self.refresh()
876
- best = self.get_best()
877
  if self.multioutput:
878
  return [eq["jax_format"] for eq in best]
879
  return best["jax_format"]
880
 
881
- def pytorch(self):
882
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
883
 
884
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
885
  containing the parameters as trainable attributes. You can use the module like
886
  any other PyTorch module: `module(X)`, where `X` is a tensor with the same
887
  column ordering as trained with.
 
 
 
 
 
 
 
 
888
  """
889
  if self.using_pandas:
890
  warnings.warn(
@@ -894,7 +941,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
894
  )
895
  self.set_params(output_torch_format=True)
896
  self.refresh()
897
- best = self.get_best()
898
  if self.multioutput:
899
  return [eq["torch_format"] for eq in best]
900
  return best["torch_format"]
 
779
  **{key: self.__getattribute__(key) for key in self.surface_parameters},
780
  }
781
 
782
+ def get_best(self, row=None):
783
+ """Get best equation using `model_selection`.
784
+
785
+ :param row: Optional. If you wish to select a particular equation
786
+ from `self.equations`, give the row number here. This overrides
787
+ the `model_selection` parameter.
788
+ :type row: int
789
+ :returns: Dictionary representing the best expression found.
790
+ :rtype: pd.Series
791
+ """
792
  if self.equations is None:
793
  raise ValueError("No equations have been generated yet.")
794
+
795
+ if row is not None:
796
+ return self.equations.iloc[row]
797
+
798
  if self.model_selection == "accuracy":
799
  if isinstance(self.equations, list):
800
  return [eq.iloc[-1] for eq in self.equations]
 
838
  # such as extra_sympy_mappings.
839
  self.equations = self.get_hof()
840
 
841
+ def predict(self, X, row=None):
842
  """Predict y from input X using the equation chosen by `model_selection`.
843
 
844
  You may see what equation is used by printing this object. X should have the same
 
846
 
847
  :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
848
  :type X: np.ndarray/pandas.DataFrame
849
+ :param row: Optional. If you want to predict using the expression in a particular row of
850
+ `self.equations`, you may specify the row here.
851
+ :type row: int
852
+ :returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
853
+ :rtype: np.ndarray
854
  """
855
  self.refresh()
856
+ best = self.get_best(row=row)
857
  if self.multioutput:
858
  return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
859
  return best["lambda_format"](X)
860
 
861
+ def sympy(self, row=None):
862
+ """Return sympy representation of the equation(s) chosen by `model_selection`.
863
+
864
+ :param row: Optional. If you wish to select a particular equation
865
+ from `self.equations`, give the row number here. This overrides
866
+ the `model_selection` parameter.
867
+ :type row: int
868
+ :returns: SymPy representation of the best expression.
869
+ """
870
  self.refresh()
871
+ best = self.get_best(row=row)
872
  if self.multioutput:
873
  return [eq["sympy_format"] for eq in best]
874
  return best["sympy_format"]
875
 
876
+ def latex(self, row=None):
877
+ """Return latex representation of the equation(s) chosen by `model_selection`.
878
+
879
+ :param row: Optional. If you wish to select a particular equation
880
+ from `self.equations`, give the row number here. This overrides
881
+ the `model_selection` parameter.
882
+ :type row: int
883
+ :returns: LaTeX expression as a string
884
+ :rtype: str
885
+ """
886
  self.refresh()
887
+ sympy_representation = self.sympy(row=row)
888
  if self.multioutput:
889
  return [sympy.latex(s) for s in sympy_representation]
890
  return sympy.latex(sympy_representation)
891
 
892
+ def jax(self, row=None):
893
  """Return jax representation of the equation(s) chosen by `model_selection`.
894
 
895
  Each equation (multiple given if there are multiple outputs) is a dictionary
896
  containing {"callable": func, "parameters": params}. To call `func`, pass
897
  func(X, params). This function is differentiable using `jax.grad`.
898
+
899
+ :param row: Optional. If you wish to select a particular equation
900
+ from `self.equations`, give the row number here. This overrides
901
+ the `model_selection` parameter.
902
+ :type row: int
903
+ :returns: Dictionary of callable jax function in "callable" key,
904
+ and jax array of parameters as "parameters" key.
905
+ :rtype: dict
906
  """
907
  if self.using_pandas:
908
  warnings.warn(
 
912
  )
913
  self.set_params(output_jax_format=True)
914
  self.refresh()
915
+ best = self.get_best(row=row)
916
  if self.multioutput:
917
  return [eq["jax_format"] for eq in best]
918
  return best["jax_format"]
919
 
920
+ def pytorch(self, row=None):
921
  """Return pytorch representation of the equation(s) chosen by `model_selection`.
922
 
923
  Each equation (multiple given if there are multiple outputs) is a PyTorch module
924
  containing the parameters as trainable attributes. You can use the module like
925
  any other PyTorch module: `module(X)`, where `X` is a tensor with the same
926
  column ordering as trained with.
927
+
928
+
929
+ :param row: Optional. If you wish to select a particular equation
930
+ from `self.equations`, give the row number here. This overrides
931
+ the `model_selection` parameter.
932
+ :type row: int
933
+ :returns: PyTorch module representing the expression.
934
+ :rtype: torch.nn.Module
935
  """
936
  if self.using_pandas:
937
  warnings.warn(
 
941
  )
942
  self.set_params(output_torch_format=True)
943
  self.refresh()
944
+ best = self.get_best(row=row)
945
  if self.multioutput:
946
  return [eq["torch_format"] for eq in best]
947
  return best["torch_format"]