Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
70b842a
1
Parent(s):
e957e34
Save options to PySRRegressor
Browse files- pysr/julia_helpers.py +12 -2
- pysr/sr.py +25 -15
pysr/julia_helpers.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
"""Functions for initializing the Julia environment and installing deps."""
|
2 |
import warnings
|
3 |
|
|
|
4 |
from juliacall import convert as jl_convert # type: ignore
|
5 |
|
6 |
from .julia_import import jl
|
@@ -8,6 +9,9 @@ from .julia_import import jl
|
|
8 |
jl.seval("using Serialization: Serialization")
|
9 |
jl.seval("using PythonCall: PythonCall")
|
10 |
|
|
|
|
|
|
|
11 |
|
12 |
def install(*args, **kwargs):
|
13 |
del args, kwargs
|
@@ -35,10 +39,16 @@ def jl_array(x):
|
|
35 |
return jl_convert(jl.Array, x)
|
36 |
|
37 |
|
38 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
if s is None:
|
40 |
return s
|
41 |
buf = jl.IOBuffer()
|
42 |
jl.write(buf, jl_array(s))
|
43 |
jl.seekstart(buf)
|
44 |
-
return
|
|
|
1 |
"""Functions for initializing the Julia environment and installing deps."""
|
2 |
import warnings
|
3 |
|
4 |
+
import numpy as np
|
5 |
from juliacall import convert as jl_convert # type: ignore
|
6 |
|
7 |
from .julia_import import jl
|
|
|
9 |
jl.seval("using Serialization: Serialization")
|
10 |
jl.seval("using PythonCall: PythonCall")
|
11 |
|
12 |
+
Serialization = jl.Serialization
|
13 |
+
PythonCall = jl.PythonCall
|
14 |
+
|
15 |
|
16 |
def install(*args, **kwargs):
|
17 |
del args, kwargs
|
|
|
39 |
return jl_convert(jl.Array, x)
|
40 |
|
41 |
|
42 |
+
def jl_serialize(obj):
|
43 |
+
buf = jl.IOBuffer()
|
44 |
+
Serialization.serialize(buf, obj)
|
45 |
+
return np.array(jl.take_b(buf))
|
46 |
+
|
47 |
+
|
48 |
+
def jl_deserialize(s):
|
49 |
if s is None:
|
50 |
return s
|
51 |
buf = jl.IOBuffer()
|
52 |
jl.write(buf, jl_array(s))
|
53 |
jl.seekstart(buf)
|
54 |
+
return Serialization.deserialize(buf)
|
pysr/sr.py
CHANGED
@@ -33,10 +33,12 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
|
|
33 |
from .export_torch import sympy2torch
|
34 |
from .feature_selection import run_feature_selection
|
35 |
from .julia_helpers import (
|
|
|
36 |
_escape_filename,
|
37 |
_load_cluster_manager,
|
38 |
jl_array,
|
39 |
-
|
|
|
40 |
)
|
41 |
from .julia_import import SymbolicRegression, jl
|
42 |
from .utils import (
|
@@ -602,11 +604,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
602 |
Path to the temporary equations directory.
|
603 |
equation_file_ : str
|
604 |
Output equation file name produced by the julia backend.
|
605 |
-
|
606 |
The serialized state for the julia SymbolicRegression.jl backend (after fitting),
|
607 |
stored as an array of uint8, produced by Julia's Serialization.serialize function.
|
608 |
-
julia_state_
|
609 |
The deserialized state.
|
|
|
|
|
|
|
|
|
610 |
equation_file_contents_ : list[pandas.DataFrame]
|
611 |
Contents of the equation file output by the Julia backend.
|
612 |
show_pickle_warnings_ : bool
|
@@ -1053,7 +1059,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1053 |
serialization.
|
1054 |
|
1055 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
1056 |
-
`
|
1057 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
1058 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
1059 |
to be serialized. Note: Jax and Torch format equations are also removed
|
@@ -1121,15 +1127,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1121 |
)
|
1122 |
return self.equations_
|
1123 |
|
|
|
|
|
|
|
|
|
1124 |
@property
|
1125 |
def julia_state_(self):
|
1126 |
-
return
|
1127 |
|
1128 |
@property
|
1129 |
def raw_julia_state_(self):
|
1130 |
warnings.warn(
|
1131 |
"PySRRegressor.raw_julia_state_ is now deprecated. "
|
1132 |
-
"Please use PySRRegressor.julia_state_ instead, or
|
1133 |
"for the raw stream of bytes.",
|
1134 |
FutureWarning,
|
1135 |
)
|
@@ -1675,6 +1685,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1675 |
define_helper_functions=False,
|
1676 |
)
|
1677 |
|
|
|
|
|
1678 |
# Convert data to desired precision
|
1679 |
test_X = np.array(X)
|
1680 |
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
@@ -1718,7 +1730,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1718 |
else:
|
1719 |
jl_y_variable_names = None
|
1720 |
|
1721 |
-
|
1722 |
out = SymbolicRegression.equation_search(
|
1723 |
jl_X,
|
1724 |
jl_y,
|
@@ -1741,12 +1753,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1741 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1742 |
verbosity=int(self.verbosity),
|
1743 |
)
|
1744 |
-
|
1745 |
|
1746 |
-
|
1747 |
-
buf = jl.IOBuffer()
|
1748 |
-
jl.Serialization.serialize(buf, out)
|
1749 |
-
self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
|
1750 |
|
1751 |
# Set attributes
|
1752 |
self.equations_ = self.get_hof()
|
@@ -1810,10 +1819,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1810 |
Fitted estimator.
|
1811 |
"""
|
1812 |
# Init attributes that are not specified in BaseEstimator
|
1813 |
-
if self.warm_start and hasattr(self, "
|
1814 |
pass
|
1815 |
else:
|
1816 |
-
if hasattr(self, "
|
1817 |
warnings.warn(
|
1818 |
"The discovered expressions are being reset. "
|
1819 |
"Please set `warm_start=True` if you wish to continue "
|
@@ -1823,7 +1832,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1823 |
self.equations_ = None
|
1824 |
self.nout_ = 1
|
1825 |
self.selection_mask_ = None
|
1826 |
-
self.
|
|
|
1827 |
self.X_units_ = None
|
1828 |
self.y_units_ = None
|
1829 |
|
|
|
33 |
from .export_torch import sympy2torch
|
34 |
from .feature_selection import run_feature_selection
|
35 |
from .julia_helpers import (
|
36 |
+
PythonCall,
|
37 |
_escape_filename,
|
38 |
_load_cluster_manager,
|
39 |
jl_array,
|
40 |
+
jl_deserialize,
|
41 |
+
jl_serialize,
|
42 |
)
|
43 |
from .julia_import import SymbolicRegression, jl
|
44 |
from .utils import (
|
|
|
604 |
Path to the temporary equations directory.
|
605 |
equation_file_ : str
|
606 |
Output equation file name produced by the julia backend.
|
607 |
+
julia_state_stream_ : ndarray
|
608 |
The serialized state for the julia SymbolicRegression.jl backend (after fitting),
|
609 |
stored as an array of uint8, produced by Julia's Serialization.serialize function.
|
610 |
+
julia_state_
|
611 |
The deserialized state.
|
612 |
+
julia_options_stream_ : ndarray
|
613 |
+
The serialized julia options, stored as an array of uint8,
|
614 |
+
julia_options_
|
615 |
+
The deserialized julia options.
|
616 |
equation_file_contents_ : list[pandas.DataFrame]
|
617 |
Contents of the equation file output by the Julia backend.
|
618 |
show_pickle_warnings_ : bool
|
|
|
1059 |
serialization.
|
1060 |
|
1061 |
Thus, for `PySRRegressor` to support pickle serialization, the
|
1062 |
+
`julia_state_stream_` attribute must be hidden from pickle. This will
|
1063 |
prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
|
1064 |
but does allow all other attributes of a fitted `PySRRegressor` estimator
|
1065 |
to be serialized. Note: Jax and Torch format equations are also removed
|
|
|
1127 |
)
|
1128 |
return self.equations_
|
1129 |
|
1130 |
+
@property
|
1131 |
+
def julia_options_(self):
|
1132 |
+
return jl_deserialize(self.julia_options_stream_)
|
1133 |
+
|
1134 |
@property
|
1135 |
def julia_state_(self):
|
1136 |
+
return jl_deserialize(self.julia_state_stream_)
|
1137 |
|
1138 |
@property
|
1139 |
def raw_julia_state_(self):
|
1140 |
warnings.warn(
|
1141 |
"PySRRegressor.raw_julia_state_ is now deprecated. "
|
1142 |
+
"Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
|
1143 |
"for the raw stream of bytes.",
|
1144 |
FutureWarning,
|
1145 |
)
|
|
|
1685 |
define_helper_functions=False,
|
1686 |
)
|
1687 |
|
1688 |
+
self.julia_options_stream_ = jl_serialize(options)
|
1689 |
+
|
1690 |
# Convert data to desired precision
|
1691 |
test_X = np.array(X)
|
1692 |
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
|
|
1730 |
else:
|
1731 |
jl_y_variable_names = None
|
1732 |
|
1733 |
+
PythonCall.GC.disable()
|
1734 |
out = SymbolicRegression.equation_search(
|
1735 |
jl_X,
|
1736 |
jl_y,
|
|
|
1753 |
progress=progress and self.verbosity > 0 and len(y.shape) == 1,
|
1754 |
verbosity=int(self.verbosity),
|
1755 |
)
|
1756 |
+
PythonCall.GC.enable()
|
1757 |
|
1758 |
+
self.julia_state_stream_ = jl_serialize(out)
|
|
|
|
|
|
|
1759 |
|
1760 |
# Set attributes
|
1761 |
self.equations_ = self.get_hof()
|
|
|
1819 |
Fitted estimator.
|
1820 |
"""
|
1821 |
# Init attributes that are not specified in BaseEstimator
|
1822 |
+
if self.warm_start and hasattr(self, "julia_state_stream_"):
|
1823 |
pass
|
1824 |
else:
|
1825 |
+
if hasattr(self, "julia_state_stream_"):
|
1826 |
warnings.warn(
|
1827 |
"The discovered expressions are being reset. "
|
1828 |
"Please set `warm_start=True` if you wish to continue "
|
|
|
1832 |
self.equations_ = None
|
1833 |
self.nout_ = 1
|
1834 |
self.selection_mask_ = None
|
1835 |
+
self.julia_state_stream_ = None
|
1836 |
+
self.julia_options_stream_ = None
|
1837 |
self.X_units_ = None
|
1838 |
self.y_units_ = None
|
1839 |
|