Spaces:
Running
Running
MilesCranmer
commited on
Commit
·
dde0ef7
1
Parent(s):
85371bb
Remove extra_sympy_mappings from pickle file
Browse files- pysr/sr.py +17 -2
pysr/sr.py
CHANGED
@@ -562,6 +562,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
562 |
equation_file_contents_ : list[pandas.DataFrame]
|
563 |
Contents of the equation file output by the Julia backend.
|
564 |
|
|
|
|
|
|
|
565 |
Notes
|
566 |
-----
|
567 |
Most default parameters have been tuned over several example equations,
|
@@ -873,14 +876,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
873 |
from the pickled instance.
|
874 |
"""
|
875 |
state = self.__dict__
|
876 |
-
|
|
|
|
|
|
|
877 |
warnings.warn(
|
878 |
"raw_julia_state_ cannot be pickled and will be removed from the "
|
879 |
"serialized instance. This will prevent a `warm_start` fit of any "
|
880 |
"model that is deserialized via `pickle.load()`."
|
881 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
882 |
pickled_state = {
|
883 |
-
key: None if key
|
884 |
for key, value in state.items()
|
885 |
}
|
886 |
if ("equations_" in pickled_state) and (
|
|
|
562 |
equation_file_contents_ : list[pandas.DataFrame]
|
563 |
Contents of the equation file output by the Julia backend.
|
564 |
|
565 |
+
show_pickle_warnings_ : bool
|
566 |
+
Whether to show warnings about what attributes can be pickled.
|
567 |
+
|
568 |
Notes
|
569 |
-----
|
570 |
Most default parameters have been tuned over several example equations,
|
|
|
876 |
from the pickled instance.
|
877 |
"""
|
878 |
state = self.__dict__
|
879 |
+
show_pickle_warning = not (
|
880 |
+
"show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
|
881 |
+
)
|
882 |
+
if "raw_julia_state_" in state and show_pickle_warning:
|
883 |
warnings.warn(
|
884 |
"raw_julia_state_ cannot be pickled and will be removed from the "
|
885 |
"serialized instance. This will prevent a `warm_start` fit of any "
|
886 |
"model that is deserialized via `pickle.load()`."
|
887 |
)
|
888 |
+
state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
|
889 |
+
for state_key in state_keys_containing_lambdas:
|
890 |
+
if state[state_key] is not None and show_pickle_warning:
|
891 |
+
warnings.warn(
|
892 |
+
f"`{state_key}` cannot be pickled and will be removed from the "
|
893 |
+
"serialized instance. When loading the model, please redefine "
|
894 |
+
f"`{state_key}` at runtime."
|
895 |
+
)
|
896 |
+
state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
|
897 |
pickled_state = {
|
898 |
+
key: (None if key in state_keys_to_clear else value)
|
899 |
for key, value in state.items()
|
900 |
}
|
901 |
if ("equations_" in pickled_state) and (
|