Spaces:
Running
Running
tttc3
commited on
Commit
•
bd90cfc
1
Parent(s):
3ef5500
Added pickle support
Browse files- pysr/export_numpy.py +4 -1
- pysr/sr.py +36 -0
- test/test.py +4 -4
pysr/export_numpy.py
CHANGED
@@ -13,7 +13,6 @@ class CallableEquation:
|
|
13 |
self._sympy_symbols = sympy_symbols
|
14 |
self._selection = selection
|
15 |
self._variable_names = variable_names
|
16 |
-
self._lambda = lambdify(sympy_symbols, eqn)
|
17 |
|
18 |
def __repr__(self):
|
19 |
return f"PySRFunction(X=>{self._sympy})"
|
@@ -35,3 +34,7 @@ class CallableEquation:
|
|
35 |
)
|
36 |
X = X[:, self._selection]
|
37 |
return self._lambda(*X.T) * np.ones(expected_shape)
|
|
|
|
|
|
|
|
|
|
13 |
self._sympy_symbols = sympy_symbols
|
14 |
self._selection = selection
|
15 |
self._variable_names = variable_names
|
|
|
16 |
|
17 |
def __repr__(self):
|
18 |
return f"PySRFunction(X=>{self._sympy})"
|
|
|
34 |
)
|
35 |
X = X[:, self._selection]
|
36 |
return self._lambda(*X.T) * np.ones(expected_shape)
|
37 |
+
|
38 |
+
@property
|
39 |
+
def _lambda(self):
|
40 |
+
return lambdify(self._sympy_symbols, self._sympy)
|
pysr/sr.py
CHANGED
@@ -816,6 +816,42 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
816 |
output += "]"
|
817 |
return output
|
818 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
819 |
@property
|
820 |
def equations(self): # pragma: no cover
|
821 |
warnings.warn(
|
|
|
816 |
output += "]"
|
817 |
return output
|
818 |
|
819 |
+
def __getstate__(self):
|
820 |
+
"""
|
821 |
+
Handles pickle serialization for PySRRegressor.
|
822 |
+
|
823 |
+
The Scikit-learn standard requires estimators to be serializable via
|
824 |
+
`pickle.dumps()`. However, `PyCall.jlwrap` does not support pickle
|
825 |
+
serialization.
|
826 |
+
|
827 |
+
Thus, for `PySRRegressor` to support pickle serialization, the
|
828 |
+
`raw_julia_state_` attribute must be hidden from pickle. This will
|
829 |
+
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
830 |
+
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
831 |
+
to be serialized. Note: Jax and Torch format equations are also removed
|
832 |
+
from the pickled instance.
|
833 |
+
"""
|
834 |
+
warnings.warn(
|
835 |
+
"raw_julia_state_ cannot be pickled and will be removed from the "
|
836 |
+
"serialized instance. This will prevent a `warm_start` fit of any "
|
837 |
+
"model that is deserialized via `pickle.loads()`."
|
838 |
+
)
|
839 |
+
state = self.__dict__
|
840 |
+
pickled_state = {
|
841 |
+
key: None if key == "raw_julia_state_" else value
|
842 |
+
for key, value in state.items()
|
843 |
+
}
|
844 |
+
if "equations_" in pickled_state:
|
845 |
+
pickled_state["output_torch_format"] = False
|
846 |
+
pickled_state["output_jax_format"] = False
|
847 |
+
pickled_columns = ~pickled_state["equations_"].columns.isin(
|
848 |
+
["jax_format", "torch_format"]
|
849 |
+
)
|
850 |
+
pickled_state["equations_"] = (
|
851 |
+
pickled_state["equations_"].loc[:, pickled_columns].copy()
|
852 |
+
)
|
853 |
+
return pickled_state
|
854 |
+
|
855 |
@property
|
856 |
def equations(self): # pragma: no cover
|
857 |
warnings.warn(
|
test/test.py
CHANGED
@@ -348,18 +348,18 @@ class TestMiscellaneous(unittest.TestCase):
|
|
348 |
max_evals=10000, verbosity=0, progress=False
|
349 |
) # Return early.
|
350 |
check_generator = check_estimator(model, generate_only=True)
|
|
|
351 |
for (_, check) in check_generator:
|
352 |
-
if "pickle" in check.func.__name__:
|
353 |
-
# Skip pickling tests.
|
354 |
-
continue
|
355 |
-
|
356 |
try:
|
357 |
with warnings.catch_warnings():
|
358 |
warnings.simplefilter("ignore")
|
359 |
check(model)
|
360 |
print("Passed", check.func.__name__)
|
361 |
except Exception as e:
|
|
|
362 |
print("Failed", check.func.__name__, "with:")
|
363 |
# Add a leading tab to error message, which
|
364 |
# might be multi-line:
|
365 |
print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
|
|
|
|
|
|
348 |
max_evals=10000, verbosity=0, progress=False
|
349 |
) # Return early.
|
350 |
check_generator = check_estimator(model, generate_only=True)
|
351 |
+
exception_messages = []
|
352 |
for (_, check) in check_generator:
|
|
|
|
|
|
|
|
|
353 |
try:
|
354 |
with warnings.catch_warnings():
|
355 |
warnings.simplefilter("ignore")
|
356 |
check(model)
|
357 |
print("Passed", check.func.__name__)
|
358 |
except Exception as e:
|
359 |
+
exception_messages.append(f"{check.func.__name__}: {e}\n")
|
360 |
print("Failed", check.func.__name__, "with:")
|
361 |
# Add a leading tab to error message, which
|
362 |
# might be multi-line:
|
363 |
print("\n".join([(" " * 4) + row for row in str(e).split("\n")]))
|
364 |
+
# If any checks failed don't let the test pass.
|
365 |
+
self.assertEqual([], exception_messages)
|