Spaces:
Sleeping
Sleeping
MilesCranmer
committed on
Commit
•
34f4e3f
1
Parent(s):
f5577ea
Change model load to classmethod
Browse files
- pysr/__init__.py +0 -1
- pysr/sr.py +113 -104
- test/test.py +4 -4
pysr/__init__.py
CHANGED
@@ -6,7 +6,6 @@ from .sr import (
|
|
6 |
best_tex,
|
7 |
best_callable,
|
8 |
best_row,
|
9 |
-
load,
|
10 |
)
|
11 |
from .julia_helpers import install
|
12 |
from .feynman_problems import Problem, FeynmanProblem
|
|
|
6 |
best_tex,
|
7 |
best_callable,
|
8 |
best_row,
|
|
|
9 |
)
|
10 |
from .julia_helpers import install
|
11 |
from .feynman_problems import Problem, FeynmanProblem
|
pysr/sr.py
CHANGED
@@ -810,6 +810,119 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
810 |
f"{k} is not a valid keyword argument for PySRRegressor."
|
811 |
)
|
812 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
def __repr__(self):
|
814 |
"""
|
815 |
Prints all current equations fitted by the model.
|
@@ -2091,107 +2204,3 @@ def _csv_filename_to_pkl_filename(csv_filename) -> str:
|
|
2091 |
pkl_basename = base + ".pkl"
|
2092 |
|
2093 |
return os.path.join(dirname, pkl_basename)
|
2094 |
-
|
2095 |
-
|
2096 |
-
def load(
|
2097 |
-
equation_file,
|
2098 |
-
*,
|
2099 |
-
binary_operators=None,
|
2100 |
-
unary_operators=None,
|
2101 |
-
n_features_in=None,
|
2102 |
-
feature_names_in=None,
|
2103 |
-
selection_mask=None,
|
2104 |
-
nout=1,
|
2105 |
-
**pysr_kwargs,
|
2106 |
-
):
|
2107 |
-
"""
|
2108 |
-
Create a model from equations stored as a csv file
|
2109 |
-
|
2110 |
-
Parameters
|
2111 |
-
----------
|
2112 |
-
equation_file : str
|
2113 |
-
Path to a csv file containing equations, or a pickle file
|
2114 |
-
containing the model.
|
2115 |
-
|
2116 |
-
binary_operators : list[str], default=["+", "-", "*", "/"]
|
2117 |
-
The same binary operators used when creating the model.
|
2118 |
-
|
2119 |
-
unary_operators : list[str], default=None
|
2120 |
-
The same unary operators used when creating the model.
|
2121 |
-
|
2122 |
-
n_features_in : int
|
2123 |
-
Number of features passed to the model.
|
2124 |
-
|
2125 |
-
feature_names_in : list[str], default=None
|
2126 |
-
Names of the features passed to the model.
|
2127 |
-
|
2128 |
-
selection_mask : list[bool], default=None
|
2129 |
-
If using select_k_features, you must pass `model.selection_mask_` here.
|
2130 |
-
|
2131 |
-
nout : int, default=1
|
2132 |
-
Number of outputs of the model.
|
2133 |
-
|
2134 |
-
pysr_kwargs : dict
|
2135 |
-
Any other keyword arguments to initialize the PySRRegressor object.
|
2136 |
-
These will overwrite those stored in the pickle file.
|
2137 |
-
|
2138 |
-
Returns
|
2139 |
-
-------
|
2140 |
-
model : PySRRegressor
|
2141 |
-
The model with fitted equations.
|
2142 |
-
"""
|
2143 |
-
if os.path.splitext(equation_file)[1] != ".pkl":
|
2144 |
-
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
|
2145 |
-
else:
|
2146 |
-
pkl_filename = equation_file
|
2147 |
-
|
2148 |
-
# Try to load model from <equation_file>.pkl
|
2149 |
-
print(f"Checking if {pkl_filename} exists...")
|
2150 |
-
if os.path.exists(pkl_filename):
|
2151 |
-
print(f"Loading model from {pkl_filename}")
|
2152 |
-
assert binary_operators is None
|
2153 |
-
assert unary_operators is None
|
2154 |
-
assert n_features_in is None
|
2155 |
-
with open(pkl_filename, "rb") as f:
|
2156 |
-
model = pkl.load(f)
|
2157 |
-
# Update any parameters if necessary, such as
|
2158 |
-
# extra_sympy_mappings:
|
2159 |
-
model.set_params(**pysr_kwargs)
|
2160 |
-
if "equations_" not in model.__dict__ or model.equations_ is None:
|
2161 |
-
model.refresh()
|
2162 |
-
|
2163 |
-
return model
|
2164 |
-
|
2165 |
-
# Else, we re-create it.
|
2166 |
-
print(
|
2167 |
-
f"{equation_file} does not exist, " "so we must create the model from scratch."
|
2168 |
-
)
|
2169 |
-
assert binary_operators is not None
|
2170 |
-
assert unary_operators is not None
|
2171 |
-
assert n_features_in is not None
|
2172 |
-
|
2173 |
-
# TODO: copy .bkup file if exists.
|
2174 |
-
model = PySRRegressor(
|
2175 |
-
equation_file=equation_file,
|
2176 |
-
binary_operators=binary_operators,
|
2177 |
-
unary_operators=unary_operators,
|
2178 |
-
**pysr_kwargs,
|
2179 |
-
)
|
2180 |
-
|
2181 |
-
model.nout_ = nout
|
2182 |
-
model.n_features_in_ = n_features_in
|
2183 |
-
|
2184 |
-
if feature_names_in is None:
|
2185 |
-
model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
|
2186 |
-
else:
|
2187 |
-
assert len(feature_names_in) == n_features_in
|
2188 |
-
model.feature_names_in_ = feature_names_in
|
2189 |
-
|
2190 |
-
if selection_mask is None:
|
2191 |
-
model.selection_mask_ = np.ones(n_features_in, dtype=bool)
|
2192 |
-
else:
|
2193 |
-
model.selection_mask_ = selection_mask
|
2194 |
-
|
2195 |
-
model.refresh(checkpoint_file=equation_file)
|
2196 |
-
|
2197 |
-
return model
|
|
|
810 |
f"{k} is not a valid keyword argument for PySRRegressor."
|
811 |
)
|
812 |
|
813 |
+
@classmethod
|
814 |
+
def from_file(
|
815 |
+
cls,
|
816 |
+
equation_file,
|
817 |
+
*,
|
818 |
+
binary_operators=None,
|
819 |
+
unary_operators=None,
|
820 |
+
n_features_in=None,
|
821 |
+
feature_names_in=None,
|
822 |
+
selection_mask=None,
|
823 |
+
nout=1,
|
824 |
+
**pysr_kwargs,
|
825 |
+
):
|
826 |
+
"""
|
827 |
+
Create a model from a saved model checkpoint or equation file.
|
828 |
+
|
829 |
+
Parameters
|
830 |
+
----------
|
831 |
+
equation_file : str
|
832 |
+
Path to a pickle file containing a saved model, or a csv file
|
833 |
+
containing equations.
|
834 |
+
|
835 |
+
binary_operators : list[str]
|
836 |
+
The same binary operators used when creating the model.
|
837 |
+
Not needed if loading from a pickle file.
|
838 |
+
|
839 |
+
unary_operators : list[str]
|
840 |
+
The same unary operators used when creating the model.
|
841 |
+
Not needed if loading from a pickle file.
|
842 |
+
|
843 |
+
n_features_in : int
|
844 |
+
Number of features passed to the model.
|
845 |
+
Not needed if loading from a pickle file.
|
846 |
+
|
847 |
+
feature_names_in : list[str]
|
848 |
+
Names of the features passed to the model.
|
849 |
+
Not needed if loading from a pickle file.
|
850 |
+
|
851 |
+
selection_mask : list[bool]
|
852 |
+
If using select_k_features, you must pass `model.selection_mask_` here.
|
853 |
+
Not needed if loading from a pickle file.
|
854 |
+
|
855 |
+
nout : int, default=1
|
856 |
+
Number of outputs of the model.
|
857 |
+
Not needed if loading from a pickle file.
|
858 |
+
|
859 |
+
pysr_kwargs : dict
|
860 |
+
Any other keyword arguments to initialize the PySRRegressor object.
|
861 |
+
These will overwrite those stored in the pickle file.
|
862 |
+
Not needed if loading from a pickle file.
|
863 |
+
|
864 |
+
Returns
|
865 |
+
-------
|
866 |
+
model : PySRRegressor
|
867 |
+
The model with fitted equations.
|
868 |
+
"""
|
869 |
+
if os.path.splitext(equation_file)[1] != ".pkl":
|
870 |
+
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
|
871 |
+
else:
|
872 |
+
pkl_filename = equation_file
|
873 |
+
|
874 |
+
# Try to load model from <equation_file>.pkl
|
875 |
+
print(f"Checking if {pkl_filename} exists...")
|
876 |
+
if os.path.exists(pkl_filename):
|
877 |
+
print(f"Loading model from {pkl_filename}")
|
878 |
+
assert binary_operators is None
|
879 |
+
assert unary_operators is None
|
880 |
+
assert n_features_in is None
|
881 |
+
with open(pkl_filename, "rb") as f:
|
882 |
+
model = pkl.load(f)
|
883 |
+
# Update any parameters if necessary, such as
|
884 |
+
# extra_sympy_mappings:
|
885 |
+
model.set_params(**pysr_kwargs)
|
886 |
+
if "equations_" not in model.__dict__ or model.equations_ is None:
|
887 |
+
model.refresh()
|
888 |
+
|
889 |
+
return model
|
890 |
+
|
891 |
+
# Else, we re-create it.
|
892 |
+
print(
|
893 |
+
f"{equation_file} does not exist, "
|
894 |
+
"so we must create the model from scratch."
|
895 |
+
)
|
896 |
+
assert binary_operators is not None
|
897 |
+
assert unary_operators is not None
|
898 |
+
assert n_features_in is not None
|
899 |
+
|
900 |
+
# TODO: copy .bkup file if exists.
|
901 |
+
model = cls(
|
902 |
+
equation_file=equation_file,
|
903 |
+
binary_operators=binary_operators,
|
904 |
+
unary_operators=unary_operators,
|
905 |
+
**pysr_kwargs,
|
906 |
+
)
|
907 |
+
|
908 |
+
model.nout_ = nout
|
909 |
+
model.n_features_in_ = n_features_in
|
910 |
+
|
911 |
+
if feature_names_in is None:
|
912 |
+
model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
|
913 |
+
else:
|
914 |
+
assert len(feature_names_in) == n_features_in
|
915 |
+
model.feature_names_in_ = feature_names_in
|
916 |
+
|
917 |
+
if selection_mask is None:
|
918 |
+
model.selection_mask_ = np.ones(n_features_in, dtype=bool)
|
919 |
+
else:
|
920 |
+
model.selection_mask_ = selection_mask
|
921 |
+
|
922 |
+
model.refresh(checkpoint_file=equation_file)
|
923 |
+
|
924 |
+
return model
|
925 |
+
|
926 |
def __repr__(self):
|
927 |
"""
|
928 |
Prints all current equations fitted by the model.
|
|
|
2204 |
pkl_basename = base + ".pkl"
|
2205 |
|
2206 |
return os.path.join(dirname, pkl_basename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test/test.py
CHANGED
@@ -4,7 +4,7 @@ import inspect
|
|
4 |
import unittest
|
5 |
import numpy as np
|
6 |
from sklearn import model_selection
|
7 |
-
from pysr import PySRRegressor, load
|
8 |
from pysr.sr import (
|
9 |
run_feature_selection,
|
10 |
_handle_feature_selection,
|
@@ -300,7 +300,7 @@ class TestPipeline(unittest.TestCase):
|
|
300 |
equation_filename = str(rand_dir / "equation.csv")
|
301 |
with open(equation_filename + (".bkup" if from_backup else ""), "w") as f:
|
302 |
f.write(csv_file_data)
|
303 |
-
model = load(
|
304 |
equation_filename,
|
305 |
n_features_in=5,
|
306 |
feature_names_in=["f0", "f1", "f2", "f3", "f4"],
|
@@ -334,7 +334,7 @@ class TestPipeline(unittest.TestCase):
|
|
334 |
|
335 |
# lambda functions are removed from the pickling, so we need
|
336 |
# to pass it during the loading:
|
337 |
-
model2 = load(
|
338 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
339 |
)
|
340 |
|
@@ -346,7 +346,7 @@ class TestPipeline(unittest.TestCase):
|
|
346 |
os.remove(file_to_delete)
|
347 |
|
348 |
pickle_file = rand_dir / "equations.pkl"
|
349 |
-
model3 = load(
|
350 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
351 |
)
|
352 |
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
|
|
|
4 |
import unittest
|
5 |
import numpy as np
|
6 |
from sklearn import model_selection
|
7 |
+
from pysr import PySRRegressor
|
8 |
from pysr.sr import (
|
9 |
run_feature_selection,
|
10 |
_handle_feature_selection,
|
|
|
300 |
equation_filename = str(rand_dir / "equation.csv")
|
301 |
with open(equation_filename + (".bkup" if from_backup else ""), "w") as f:
|
302 |
f.write(csv_file_data)
|
303 |
+
model = PySRRegressor.from_file(
|
304 |
equation_filename,
|
305 |
n_features_in=5,
|
306 |
feature_names_in=["f0", "f1", "f2", "f3", "f4"],
|
|
|
334 |
|
335 |
# lambda functions are removed from the pickling, so we need
|
336 |
# to pass it during the loading:
|
337 |
+
model2 = PySRRegressor.from_file(
|
338 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
339 |
)
|
340 |
|
|
|
346 |
os.remove(file_to_delete)
|
347 |
|
348 |
pickle_file = rand_dir / "equations.pkl"
|
349 |
+
model3 = PySRRegressor.from_file(
|
350 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
351 |
)
|
352 |
np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
|