Spaces:
Sleeping
Sleeping
MilesCranmer
committed on
Commit
•
518eb85
1
Parent(s):
4839f5f
Change row to index in specifying which expression
Browse files- pysr/sr.py +29 -29
pysr/sr.py
CHANGED
@@ -779,21 +779,21 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
779 |
**{key: self.__getattribute__(key) for key in self.surface_parameters},
|
780 |
}
|
781 |
|
782 |
-
def get_best(self,
|
783 |
"""Get best equation using `model_selection`.
|
784 |
|
785 |
-
:param
|
786 |
from `self.equations`, give the row number here. This overrides
|
787 |
the `model_selection` parameter.
|
788 |
-
:type
|
789 |
:returns: Dictionary representing the best expression found.
|
790 |
:type: pd.Series
|
791 |
"""
|
792 |
if self.equations is None:
|
793 |
raise ValueError("No equations have been generated yet.")
|
794 |
|
795 |
-
if
|
796 |
-
return self.equations.iloc[
|
797 |
|
798 |
if self.model_selection == "accuracy":
|
799 |
if isinstance(self.equations, list):
|
@@ -838,7 +838,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
838 |
# such as extra_sympy_mappings.
|
839 |
self.equations = self.get_hof()
|
840 |
|
841 |
-
def predict(self, X,
|
842 |
"""Predict y from input X using the equation chosen by `model_selection`.
|
843 |
|
844 |
You may see what equation is used by printing this object. X should have the same
|
@@ -846,60 +846,60 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
846 |
|
847 |
:param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
|
848 |
:type X: np.ndarray/pandas.DataFrame
|
849 |
-
:param
|
850 |
-
`self.equations`, you may specify the
|
851 |
-
:type
|
852 |
:returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
|
853 |
:type: np.ndarray
|
854 |
"""
|
855 |
self.refresh()
|
856 |
-
best = self.get_best(
|
857 |
if self.multioutput:
|
858 |
return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
|
859 |
return best["lambda_format"](X)
|
860 |
|
861 |
-
def sympy(self,
|
862 |
"""Return sympy representation of the equation(s) chosen by `model_selection`.
|
863 |
|
864 |
-
:param
|
865 |
-
from `self.equations`, give the
|
866 |
the `model_selection` parameter.
|
867 |
-
:type
|
868 |
:returns: SymPy representation of the best expression.
|
869 |
"""
|
870 |
self.refresh()
|
871 |
-
best = self.get_best(
|
872 |
if self.multioutput:
|
873 |
return [eq["sympy_format"] for eq in best]
|
874 |
return best["sympy_format"]
|
875 |
|
876 |
-
def latex(self,
|
877 |
"""Return latex representation of the equation(s) chosen by `model_selection`.
|
878 |
|
879 |
-
:param
|
880 |
-
from `self.equations`, give the
|
881 |
the `model_selection` parameter.
|
882 |
-
:type
|
883 |
:returns: LaTeX expression as a string
|
884 |
:type: str
|
885 |
"""
|
886 |
self.refresh()
|
887 |
-
sympy_representation = self.sympy(
|
888 |
if self.multioutput:
|
889 |
return [sympy.latex(s) for s in sympy_representation]
|
890 |
return sympy.latex(sympy_representation)
|
891 |
|
892 |
-
def jax(self,
|
893 |
"""Return jax representation of the equation(s) chosen by `model_selection`.
|
894 |
|
895 |
Each equation (multiple given if there are multiple outputs) is a dictionary
|
896 |
containing {"callable": func, "parameters": params}. To call `func`, pass
|
897 |
func(X, params). This function is differentiable using `jax.grad`.
|
898 |
|
899 |
-
:param
|
900 |
-
from `self.equations`, give the
|
901 |
the `model_selection` parameter.
|
902 |
-
:type
|
903 |
:returns: Dictionary of callable jax function in "callable" key,
|
904 |
and jax array of parameters as "parameters" key.
|
905 |
:type: dict
|
@@ -912,12 +912,12 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
912 |
)
|
913 |
self.set_params(output_jax_format=True)
|
914 |
self.refresh()
|
915 |
-
best = self.get_best(
|
916 |
if self.multioutput:
|
917 |
return [eq["jax_format"] for eq in best]
|
918 |
return best["jax_format"]
|
919 |
|
920 |
-
def pytorch(self,
|
921 |
"""Return pytorch representation of the equation(s) chosen by `model_selection`.
|
922 |
|
923 |
Each equation (multiple given if there are multiple outputs) is a PyTorch module
|
@@ -926,10 +926,10 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
926 |
column ordering as trained with.
|
927 |
|
928 |
|
929 |
-
:param
|
930 |
from `self.equations`, give the row number here. This overrides
|
931 |
the `model_selection` parameter.
|
932 |
-
:type
|
933 |
:returns: PyTorch module representing the expression.
|
934 |
:type: torch.nn.Module
|
935 |
"""
|
@@ -941,7 +941,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
941 |
)
|
942 |
self.set_params(output_torch_format=True)
|
943 |
self.refresh()
|
944 |
-
best = self.get_best(
|
945 |
if self.multioutput:
|
946 |
return [eq["torch_format"] for eq in best]
|
947 |
return best["torch_format"]
|
|
|
779 |
**{key: self.__getattribute__(key) for key in self.surface_parameters},
|
780 |
}
|
781 |
|
782 |
+
def get_best(self, index=None):
|
783 |
"""Get best equation using `model_selection`.
|
784 |
|
785 |
+
:param index: Optional. If you wish to select a particular equation
|
786 |
from `self.equations`, give the row number here. This overrides
|
787 |
the `model_selection` parameter.
|
788 |
+
:type index: int
|
789 |
:returns: Dictionary representing the best expression found.
|
790 |
:type: pd.Series
|
791 |
"""
|
792 |
if self.equations is None:
|
793 |
raise ValueError("No equations have been generated yet.")
|
794 |
|
795 |
+
if index is not None:
|
796 |
+
return self.equations.iloc[index]
|
797 |
|
798 |
if self.model_selection == "accuracy":
|
799 |
if isinstance(self.equations, list):
|
|
|
838 |
# such as extra_sympy_mappings.
|
839 |
self.equations = self.get_hof()
|
840 |
|
841 |
+
def predict(self, X, index=None):
|
842 |
"""Predict y from input X using the equation chosen by `model_selection`.
|
843 |
|
844 |
You may see what equation is used by printing this object. X should have the same
|
|
|
846 |
|
847 |
:param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
|
848 |
:type X: np.ndarray/pandas.DataFrame
|
849 |
+
:param index: Optional. If you want to predict an expression using a particular row of
|
850 |
+
`self.equations`, you may specify the index here.
|
851 |
+
:type index: int
|
852 |
:returns: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs).
|
853 |
:type: np.ndarray
|
854 |
"""
|
855 |
self.refresh()
|
856 |
+
best = self.get_best(index=index)
|
857 |
if self.multioutput:
|
858 |
return np.stack([eq["lambda_format"](X) for eq in best], axis=1)
|
859 |
return best["lambda_format"](X)
|
860 |
|
861 |
+
def sympy(self, index=None):
|
862 |
"""Return sympy representation of the equation(s) chosen by `model_selection`.
|
863 |
|
864 |
+
:param index: Optional. If you wish to select a particular equation
|
865 |
+
from `self.equations`, give the index number here. This overrides
|
866 |
the `model_selection` parameter.
|
867 |
+
:type index: int
|
868 |
:returns: SymPy representation of the best expression.
|
869 |
"""
|
870 |
self.refresh()
|
871 |
+
best = self.get_best(index=index)
|
872 |
if self.multioutput:
|
873 |
return [eq["sympy_format"] for eq in best]
|
874 |
return best["sympy_format"]
|
875 |
|
876 |
+
def latex(self, index=None):
|
877 |
"""Return latex representation of the equation(s) chosen by `model_selection`.
|
878 |
|
879 |
+
:param index: Optional. If you wish to select a particular equation
|
880 |
+
from `self.equations`, give the index number here. This overrides
|
881 |
the `model_selection` parameter.
|
882 |
+
:type index: int
|
883 |
:returns: LaTeX expression as a string
|
884 |
:type: str
|
885 |
"""
|
886 |
self.refresh()
|
887 |
+
sympy_representation = self.sympy(index=index)
|
888 |
if self.multioutput:
|
889 |
return [sympy.latex(s) for s in sympy_representation]
|
890 |
return sympy.latex(sympy_representation)
|
891 |
|
892 |
+
def jax(self, index=None):
|
893 |
"""Return jax representation of the equation(s) chosen by `model_selection`.
|
894 |
|
895 |
Each equation (multiple given if there are multiple outputs) is a dictionary
|
896 |
containing {"callable": func, "parameters": params}. To call `func`, pass
|
897 |
func(X, params). This function is differentiable using `jax.grad`.
|
898 |
|
899 |
+
:param index: Optional. If you wish to select a particular equation
|
900 |
+
from `self.equations`, give the index number here. This overrides
|
901 |
the `model_selection` parameter.
|
902 |
+
:type index: int
|
903 |
:returns: Dictionary of callable jax function in "callable" key,
|
904 |
and jax array of parameters as "parameters" key.
|
905 |
:type: dict
|
|
|
912 |
)
|
913 |
self.set_params(output_jax_format=True)
|
914 |
self.refresh()
|
915 |
+
best = self.get_best(index=index)
|
916 |
if self.multioutput:
|
917 |
return [eq["jax_format"] for eq in best]
|
918 |
return best["jax_format"]
|
919 |
|
920 |
+
def pytorch(self, index=None):
|
921 |
"""Return pytorch representation of the equation(s) chosen by `model_selection`.
|
922 |
|
923 |
Each equation (multiple given if there are multiple outputs) is a PyTorch module
|
|
|
926 |
column ordering as trained with.
|
927 |
|
928 |
|
929 |
+
:param index: Optional. If you wish to select a particular equation
|
930 |
from `self.equations`, give the row number here. This overrides
|
931 |
the `model_selection` parameter.
|
932 |
+
:type index: int
|
933 |
:returns: PyTorch module representing the expression.
|
934 |
:type: torch.nn.Module
|
935 |
"""
|
|
|
941 |
)
|
942 |
self.set_params(output_torch_format=True)
|
943 |
self.refresh()
|
944 |
+
best = self.get_best(index=index)
|
945 |
if self.multioutput:
|
946 |
return [eq["torch_format"] for eq in best]
|
947 |
return best["torch_format"]
|