MilesCranmer commited on
Commit
70b842a
1 Parent(s): e957e34

Save options to PySRRegressor

Browse files
Files changed (2) hide show
  1. pysr/julia_helpers.py +12 -2
  2. pysr/sr.py +25 -15
pysr/julia_helpers.py CHANGED
@@ -1,6 +1,7 @@
1
  """Functions for initializing the Julia environment and installing deps."""
2
  import warnings
3
 
 
4
  from juliacall import convert as jl_convert # type: ignore
5
 
6
  from .julia_import import jl
@@ -8,6 +9,9 @@ from .julia_import import jl
8
  jl.seval("using Serialization: Serialization")
9
  jl.seval("using PythonCall: PythonCall")
10
 
 
 
 
11
 
12
  def install(*args, **kwargs):
13
  del args, kwargs
@@ -35,10 +39,16 @@ def jl_array(x):
35
  return jl_convert(jl.Array, x)
36
 
37
 
38
- def jl_deserialize_s(s):
 
 
 
 
 
 
39
  if s is None:
40
  return s
41
  buf = jl.IOBuffer()
42
  jl.write(buf, jl_array(s))
43
  jl.seekstart(buf)
44
- return jl.Serialization.deserialize(buf)
 
1
  """Functions for initializing the Julia environment and installing deps."""
2
  import warnings
3
 
4
+ import numpy as np
5
  from juliacall import convert as jl_convert # type: ignore
6
 
7
  from .julia_import import jl
 
9
  jl.seval("using Serialization: Serialization")
10
  jl.seval("using PythonCall: PythonCall")
11
 
12
+ Serialization = jl.Serialization
13
+ PythonCall = jl.PythonCall
14
+
15
 
16
  def install(*args, **kwargs):
17
  del args, kwargs
 
39
  return jl_convert(jl.Array, x)
40
 
41
 
42
+ def jl_serialize(obj):
43
+ buf = jl.IOBuffer()
44
+ Serialization.serialize(buf, obj)
45
+ return np.array(jl.take_b(buf))
46
+
47
+
48
+ def jl_deserialize(s):
49
  if s is None:
50
  return s
51
  buf = jl.IOBuffer()
52
  jl.write(buf, jl_array(s))
53
  jl.seekstart(buf)
54
+ return Serialization.deserialize(buf)
pysr/sr.py CHANGED
@@ -33,10 +33,12 @@ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
35
  from .julia_helpers import (
 
36
  _escape_filename,
37
  _load_cluster_manager,
38
  jl_array,
39
- jl_deserialize_s,
 
40
  )
41
  from .julia_import import SymbolicRegression, jl
42
  from .utils import (
@@ -602,11 +604,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
602
  Path to the temporary equations directory.
603
  equation_file_ : str
604
  Output equation file name produced by the julia backend.
605
- raw_julia_state_stream_ : ndarray
606
  The serialized state for the julia SymbolicRegression.jl backend (after fitting),
607
  stored as an array of uint8, produced by Julia's Serialization.serialize function.
608
- julia_state_ : ndarray
609
  The deserialized state.
 
 
 
 
610
  equation_file_contents_ : list[pandas.DataFrame]
611
  Contents of the equation file output by the Julia backend.
612
  show_pickle_warnings_ : bool
@@ -1053,7 +1059,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1053
  serialization.
1054
 
1055
  Thus, for `PySRRegressor` to support pickle serialization, the
1056
- `raw_julia_state_stream_` attribute must be hidden from pickle. This will
1057
  prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
1058
  but does allow all other attributes of a fitted `PySRRegressor` estimator
1059
  to be serialized. Note: Jax and Torch format equations are also removed
@@ -1121,15 +1127,19 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1121
  )
1122
  return self.equations_
1123
 
 
 
 
 
1124
  @property
1125
  def julia_state_(self):
1126
- return jl_deserialize_s(self.raw_julia_state_stream_)
1127
 
1128
  @property
1129
  def raw_julia_state_(self):
1130
  warnings.warn(
1131
  "PySRRegressor.raw_julia_state_ is now deprecated. "
1132
- "Please use PySRRegressor.julia_state_ instead, or raw_julia_state_stream_ "
1133
  "for the raw stream of bytes.",
1134
  FutureWarning,
1135
  )
@@ -1675,6 +1685,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1675
  define_helper_functions=False,
1676
  )
1677
 
 
 
1678
  # Convert data to desired precision
1679
  test_X = np.array(X)
1680
  is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
@@ -1718,7 +1730,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1718
  else:
1719
  jl_y_variable_names = None
1720
 
1721
- jl.PythonCall.GC.disable()
1722
  out = SymbolicRegression.equation_search(
1723
  jl_X,
1724
  jl_y,
@@ -1741,12 +1753,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1741
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1742
  verbosity=int(self.verbosity),
1743
  )
1744
- jl.PythonCall.GC.enable()
1745
 
1746
- # Serialize output (for pickling)
1747
- buf = jl.IOBuffer()
1748
- jl.Serialization.serialize(buf, out)
1749
- self.raw_julia_state_stream_ = np.array(jl.take_b(buf))
1750
 
1751
  # Set attributes
1752
  self.equations_ = self.get_hof()
@@ -1810,10 +1819,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1810
  Fitted estimator.
1811
  """
1812
  # Init attributes that are not specified in BaseEstimator
1813
- if self.warm_start and hasattr(self, "raw_julia_state_stream_"):
1814
  pass
1815
  else:
1816
- if hasattr(self, "raw_julia_state_stream_"):
1817
  warnings.warn(
1818
  "The discovered expressions are being reset. "
1819
  "Please set `warm_start=True` if you wish to continue "
@@ -1823,7 +1832,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1823
  self.equations_ = None
1824
  self.nout_ = 1
1825
  self.selection_mask_ = None
1826
- self.raw_julia_state_stream_ = None
 
1827
  self.X_units_ = None
1828
  self.y_units_ = None
1829
 
 
33
  from .export_torch import sympy2torch
34
  from .feature_selection import run_feature_selection
35
  from .julia_helpers import (
36
+ PythonCall,
37
  _escape_filename,
38
  _load_cluster_manager,
39
  jl_array,
40
+ jl_deserialize,
41
+ jl_serialize,
42
  )
43
  from .julia_import import SymbolicRegression, jl
44
  from .utils import (
 
604
  Path to the temporary equations directory.
605
  equation_file_ : str
606
  Output equation file name produced by the julia backend.
607
+ julia_state_stream_ : ndarray
608
  The serialized state for the julia SymbolicRegression.jl backend (after fitting),
609
  stored as an array of uint8, produced by Julia's Serialization.serialize function.
610
+ julia_state_
611
  The deserialized state.
612
+ julia_options_stream_ : ndarray
613
+ The serialized julia options, stored as an array of uint8,
614
+ julia_options_
615
+ The deserialized julia options.
616
  equation_file_contents_ : list[pandas.DataFrame]
617
  Contents of the equation file output by the Julia backend.
618
  show_pickle_warnings_ : bool
 
1059
  serialization.
1060
 
1061
  Thus, for `PySRRegressor` to support pickle serialization, the
1062
+ `julia_state_stream_` attribute must be hidden from pickle. This will
1063
  prevent the `warm_start` of any model that is loaded via `pickle.loads()`,
1064
  but does allow all other attributes of a fitted `PySRRegressor` estimator
1065
  to be serialized. Note: Jax and Torch format equations are also removed
 
1127
  )
1128
  return self.equations_
1129
 
1130
+ @property
1131
+ def julia_options_(self):
1132
+ return jl_deserialize(self.julia_options_stream_)
1133
+
1134
  @property
1135
  def julia_state_(self):
1136
+ return jl_deserialize(self.julia_state_stream_)
1137
 
1138
  @property
1139
  def raw_julia_state_(self):
1140
  warnings.warn(
1141
  "PySRRegressor.raw_julia_state_ is now deprecated. "
1142
+ "Please use PySRRegressor.julia_state_ instead, or julia_state_stream_ "
1143
  "for the raw stream of bytes.",
1144
  FutureWarning,
1145
  )
 
1685
  define_helper_functions=False,
1686
  )
1687
 
1688
+ self.julia_options_stream_ = jl_serialize(options)
1689
+
1690
  # Convert data to desired precision
1691
  test_X = np.array(X)
1692
  is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
 
1730
  else:
1731
  jl_y_variable_names = None
1732
 
1733
+ PythonCall.GC.disable()
1734
  out = SymbolicRegression.equation_search(
1735
  jl_X,
1736
  jl_y,
 
1753
  progress=progress and self.verbosity > 0 and len(y.shape) == 1,
1754
  verbosity=int(self.verbosity),
1755
  )
1756
+ PythonCall.GC.enable()
1757
 
1758
+ self.julia_state_stream_ = jl_serialize(out)
 
 
 
1759
 
1760
  # Set attributes
1761
  self.equations_ = self.get_hof()
 
1819
  Fitted estimator.
1820
  """
1821
  # Init attributes that are not specified in BaseEstimator
1822
+ if self.warm_start and hasattr(self, "julia_state_stream_"):
1823
  pass
1824
  else:
1825
+ if hasattr(self, "julia_state_stream_"):
1826
  warnings.warn(
1827
  "The discovered expressions are being reset. "
1828
  "Please set `warm_start=True` if you wish to continue "
 
1832
  self.equations_ = None
1833
  self.nout_ = 1
1834
  self.selection_mask_ = None
1835
+ self.julia_state_stream_ = None
1836
+ self.julia_options_stream_ = None
1837
  self.X_units_ = None
1838
  self.y_units_ = None
1839