MilesCranmer commited on
Commit
d72c643
1 Parent(s): 7a42396

Save raw bytes so can warm-restart in new python session

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +10 -2
  2. pysr/sr.py +34 -22
pysr/julia_helpers.py CHANGED
@@ -22,8 +22,7 @@ import juliapkg
22
  from juliacall import Main as jl
23
  from juliacall import convert as jl_convert
24
 
25
- jl.seval("using PythonCall: PythonCall")
26
- PythonCall = jl.PythonCall
27
 
28
  juliainfo = None
29
  julia_initialized = False
@@ -63,3 +62,12 @@ def jl_array(x):
63
  if x is None:
64
  return None
65
  return jl_convert(jl.Array, x)
 
 
 
 
 
 
 
 
 
 
22
  from juliacall import Main as jl
23
  from juliacall import convert as jl_convert
24
 
25
+ jl.seval("using Serialization: Serialization")
 
26
 
27
  juliainfo = None
28
  julia_initialized = False
 
62
  if x is None:
63
  return None
64
  return jl_convert(jl.Array, x)
65
+
66
+
67
+ def jl_deserialize_s(s):
68
+ if s is None:
69
+ return s
70
+ buf = jl.IOBuffer()
71
+ jl.write(buf, jl_array(s))
72
+ jl.seekstart(buf)
73
+ return jl.Serialization.deserialize(buf)
pysr/sr.py CHANGED
@@ -34,12 +34,11 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
34
  from .export_torch import sympy2torch
35
  from .feature_selection import run_feature_selection
36
  from .julia_helpers import (
37
- PythonCall,
38
  _escape_filename,
39
  _load_cluster_manager,
40
  jl,
41
  jl_array,
42
- jl_convert,
43
  )
44
  from .utils import (
45
  _csv_filename_to_pkl_filename,
@@ -614,8 +613,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
614
  Path to the temporary equations directory.
615
  equation_file_ : str
616
  Output equation file name produced by the julia backend.
617
- raw_julia_state_ : tuple[list[PyCall.jlwrap], PyCall.jlwrap]
618
- The state for the julia SymbolicRegression.jl backend post fitting.
619
  equation_file_contents_ : list[pandas.DataFrame]
620
  Contents of the equation file output by the Julia backend.
621
  show_pickle_warnings_ : bool
@@ -1048,22 +1047,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1048
  serialization.
1049
 
1050
  Thus, for `PySRRegressor` to support pickle serialization, the
1051
- `raw_julia_state_` attribute must be hidden from pickle. This will
1052
  prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
1053
  but does allow all other attributes of a fitted `PySRRegressor` estimator
1054
  to be serialized. Note: Jax and Torch format equations are also removed
1055
  from the pickled instance.
1056
  """
1057
  state = self.__dict__
1058
- show_pickle_warning = not (
1059
- "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
1060
- )
1061
- if "raw_julia_state_" in state and show_pickle_warning:
1062
- warnings.warn(
1063
- "raw_julia_state_ cannot be pickled and will be removed from the "
1064
- "serialized instance. This will prevent a `warm_start` fit of any "
1065
- "model that is deserialized via `pickle.load()`."
1066
- )
1067
  state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
1068
  for state_key in state_keys_containing_lambdas:
1069
  if state[state_key] is not None and show_pickle_warning:
@@ -1072,7 +1062,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1072
  "serialized instance. When loading the model, please redefine "
1073
  f"`{state_key}` at runtime."
1074
  )
1075
- state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
1076
  pickled_state = {
1077
  key: (None if key in state_keys_to_clear else value)
1078
  for key, value in state.items()
@@ -1122,6 +1112,20 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1122
  )
1123
  return self.equations_
1124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1125
  def get_best(self, index=None):
1126
  """
1127
  Get best equation using `model_selection`.
@@ -1724,7 +1728,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1724
  # Python's garbage collection is unaware of them.
1725
  jl._equation_search_args = (jl_X, jl_y)
1726
  jl._equation_search_kwargs = namedtuple(
1727
- "K",
1728
  (
1729
  "weights",
1730
  "niterations",
@@ -1754,18 +1758,26 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1754
  options=options,
1755
  numprocs=cprocs,
1756
  parallelism=parallelism,
1757
- saved_state=self.raw_julia_state_,
1758
  return_state=True,
1759
  addprocs_function=cluster_manager,
1760
  heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
1761
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1762
  verbosity=int(self.verbosity),
1763
  )
1764
- self.raw_julia_state_ = jl.seval(
1765
- "deepcopy(SymbolicRegression.equation_search(deepcopy(_equation_search_args)...; deepcopy(_equation_search_kwargs)...))"
 
 
 
 
 
 
 
1766
  )
1767
  jl._equation_search_args = None
1768
  jl._equation_search_kwargs = None
 
1769
 
1770
  # Set attributes
1771
  self.equations_ = self.get_hof()
@@ -1829,10 +1841,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1829
  Fitted estimator.
1830
  """
1831
  # Init attributes that are not specified in BaseEstimator
1832
- if self.warm_start and hasattr(self, "raw_julia_state_"):
1833
  pass
1834
  else:
1835
- if hasattr(self, "raw_julia_state_"):
1836
  warnings.warn(
1837
  "The discovered expressions are being reset. "
1838
  "Please set `warm_start=True` if you wish to continue "
@@ -1842,7 +1854,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1842
  self.equations_ = None
1843
  self.nout_ = 1
1844
  self.selection_mask_ = None
1845
- self.raw_julia_state_ = None
1846
  self.X_units_ = None
1847
  self.y_units_ = None
1848
 
 
34
  from .export_torch import sympy2torch
35
  from .feature_selection import run_feature_selection
36
  from .julia_helpers import (
 
37
  _escape_filename,
38
  _load_cluster_manager,
39
  jl,
40
  jl_array,
41
+ jl_deserialize_s,
42
  )
43
  from .utils import (
44
  _csv_filename_to_pkl_filename,
 
613
  Path to the temporary equations directory.
614
  equation_file_ : str
615
  Output equation file name produced by the julia backend.
616
+ raw_julia_state_stream_ : ndarray
617
+ The serialized state for the julia SymbolicRegression.jl backend (after fitting).
618
  equation_file_contents_ : list[pandas.DataFrame]
619
  Contents of the equation file output by the Julia backend.
620
  show_pickle_warnings_ : bool
 
1047
  serialization.
1048
 
1049
  Thus, for `PySRRegressor` to support pickle serialization, the
1050
+ `raw_julia_state_stream_` attribute must be hidden from pickle. This will
1051
  prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
1052
  but does allow all other attributes of a fitted `PySRRegressor` estimator
1053
  to be serialized. Note: Jax and Torch format equations are also removed
1054
  from the pickled instance.
1055
  """
1056
  state = self.__dict__
 
 
 
 
 
 
 
 
 
1057
  state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
1058
  for state_key in state_keys_containing_lambdas:
1059
  if state[state_key] is not None and show_pickle_warning:
 
1062
  "serialized instance. When loading the model, please redefine "
1063
  f"`{state_key}` at runtime."
1064
  )
1065
+ state_keys_to_clear = state_keys_containing_lambdas
1066
  pickled_state = {
1067
  key: (None if key in state_keys_to_clear else value)
1068
  for key, value in state.items()
 
1112
  )
1113
  return self.equations_
1114
 
1115
+ @property
1116
+ def julia_state(self):
1117
+ return jl_deserialize_s(self.raw_julia_state_stream_)
1118
+
1119
+ @property
1120
+ def raw_julia_state_(self):
1121
+ warnings.warn(
1122
+ "PySRRegressor.raw_julia_state_ is now deprecated. "
1123
+ "Please use PySRRegressor.julia_state instead, or `raw_julia_state_stream_` "
1124
+ "for the raw stream of bytes.",
1125
+ FutureWarning,
1126
+ )
1127
+ return self.julia_state
1128
+
1129
  def get_best(self, index=None):
1130
  """
1131
  Get best equation using `model_selection`.
 
1728
  # Python's garbage collection is unaware of them.
1729
  jl._equation_search_args = (jl_X, jl_y)
1730
  jl._equation_search_kwargs = namedtuple(
1731
+ "equation_search_kwargs",
1732
  (
1733
  "weights",
1734
  "niterations",
 
1758
  options=options,
1759
  numprocs=cprocs,
1760
  parallelism=parallelism,
1761
+ saved_state=self.julia_state,
1762
  return_state=True,
1763
  addprocs_function=cluster_manager,
1764
  heap_size_hint_in_bytes=self.heap_size_hint_in_bytes,
1765
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1766
  verbosity=int(self.verbosity),
1767
  )
1768
+ output_stream = jl.seval(
1769
+ """
1770
+ let args = deepcopy(_equation_search_args), kwargs=deepcopy(_equation_search_kwargs)
1771
+ out = SymbolicRegression.equation_search(args...; kwargs...)
1772
+ buf = IOBuffer()
1773
+ Serialization.serialize(buf, out)
1774
+ take!(buf)
1775
+ end
1776
+ """
1777
  )
1778
  jl._equation_search_args = None
1779
  jl._equation_search_kwargs = None
1780
+ self.raw_julia_state_stream_ = np.array(output_stream).copy()
1781
 
1782
  # Set attributes
1783
  self.equations_ = self.get_hof()
 
1841
  Fitted estimator.
1842
  """
1843
  # Init attributes that are not specified in BaseEstimator
1844
+ if self.warm_start and hasattr(self, "raw_julia_state_stream_"):
1845
  pass
1846
  else:
1847
+ if hasattr(self, "raw_julia_state_stream_"):
1848
  warnings.warn(
1849
  "The discovered expressions are being reset. "
1850
  "Please set `warm_start=True` if you wish to continue "
 
1854
  self.equations_ = None
1855
  self.nout_ = 1
1856
  self.selection_mask_ = None
1857
+ self.raw_julia_state_stream_ = None
1858
  self.X_units_ = None
1859
  self.y_units_ = None
1860