MilesCranmer commited on
Commit
1099283
·
unverified ·
2 Parent(s): 8da5000 9433a83

Merge pull request #167 from MilesCranmer/loading

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. README.md +9 -1
  3. pysr/sr.py +184 -12
  4. test/test.py +89 -4
  5. test/test_jax.py +1 -1
.gitignore CHANGED
@@ -3,6 +3,7 @@
3
  *.csv
4
  *.csv.out*
5
  *.bkup
 
6
  performance*txt
7
  *.out
8
  trials*
 
3
  *.csv
4
  *.csv.out*
5
  *.bkup
6
+ *.pkl
7
  performance*txt
8
  *.out
9
  trials*
README.md CHANGED
@@ -162,7 +162,15 @@ This arrow in the `pick` column indicates which equation is currently selected b
162
  SymPy format (`sympy_format` - which you can also get with `model.sympy()`), and even JAX and PyTorch format
163
  (both of which are differentiable - which you can get with `model.jax()` and `model.pytorch()`).
164
 
165
- Note that `PySRRegressor` stores the state of the last search, and will restart from where you left off the next time you call `.fit()`. This will cause problems if significant changes are made to the search parameters (like changing the operators). You can run `model.reset()` to reset the state.
 
 
 
 
 
 
 
 
166
 
167
  There are several other useful features such as denoising (e.g., `denoising=True`),
168
  feature selection (e.g., `select_k_features=3`).
 
162
  SymPy format (`sympy_format` - which you can also get with `model.sympy()`), and even JAX and PyTorch format
163
  (both of which are differentiable - which you can get with `model.jax()` and `model.pytorch()`).
164
 
165
+ Note that `PySRRegressor` stores the state of the last search, and will restart from where you left off the next time you call `.fit()`, assuming you have set `warm_start=True`.
166
+ This will cause problems if significant changes are made to the search parameters (like changing the operators). You can run `model.reset()` to reset the state.
167
+
168
+ You will notice that PySR will save two files: `hall_of_fame...csv` and `hall_of_fame...pkl`.
169
+ The csv file is a list of equations and their losses, and the pkl file is a saved state of the model.
170
+ You may load the model from the `pkl` file with:
171
+ ```python
172
+ model = PySRRegressor.from_file("hall_of_fame.2022-08-10_100832.281.pkl")
173
+ ```
174
 
175
  There are several other useful features such as denoising (e.g., `denoising=True`),
176
  feature selection (e.g., `select_k_features=3`).
pysr/sr.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import sys
3
  import numpy as np
@@ -8,6 +9,7 @@ import re
8
  import tempfile
9
  import shutil
10
  from pathlib import Path
 
11
  from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
@@ -562,6 +564,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
562
  equation_file_contents_ : list[pandas.DataFrame]
563
  Contents of the equation file output by the Julia backend.
564
 
 
 
 
565
  Notes
566
  -----
567
  Most default parameters have been tuned over several example equations,
@@ -805,6 +810,119 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
805
  f"{k} is not a valid keyword argument for PySRRegressor."
806
  )
807
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
  def __repr__(self):
809
  """
810
  Prints all current equations fitted by the model.
@@ -873,17 +991,31 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
873
  from the pickled instance.
874
  """
875
  state = self.__dict__
876
- if "raw_julia_state_" in state:
 
 
 
877
  warnings.warn(
878
  "raw_julia_state_ cannot be pickled and will be removed from the "
879
  "serialized instance. This will prevent a `warm_start` fit of any "
880
  "model that is deserialized via `pickle.load()`."
881
  )
 
 
 
 
 
 
 
 
 
882
  pickled_state = {
883
- key: None if key == "raw_julia_state_" else value
884
  for key, value in state.items()
885
  }
886
- if "equations_" in pickled_state:
 
 
887
  pickled_state["output_torch_format"] = False
888
  pickled_state["output_jax_format"] = False
889
  if self.nout_ == 1:
@@ -906,6 +1038,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
906
  ]
907
  return pickled_state
908
 
 
 
 
 
 
 
 
 
 
 
909
  @property
910
  def equations(self): # pragma: no cover
911
  warnings.warn(
@@ -1606,8 +1748,20 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1606
  y,
1607
  )
1608
 
1609
- # Fitting procedure
1610
- return self._run(X, y, mutated_params, weights=weights, seed=seed)
 
 
 
 
 
 
 
 
 
 
 
 
1611
 
1612
  def refresh(self, checkpoint_file=None):
1613
  """
@@ -1619,10 +1773,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1619
  checkpoint_file : str, default=None
1620
  Path to checkpoint hall of fame file to be loaded.
1621
  """
1622
- check_is_fitted(self, attributes=["equation_file_"])
1623
  if checkpoint_file:
1624
  self.equation_file_ = checkpoint_file
1625
  self.equation_file_contents_ = None
 
1626
  self.equations_ = self.get_hof()
1627
 
1628
  def predict(self, X, index=None):
@@ -1812,10 +1966,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1812
  if self.nout_ > 1:
1813
  all_outputs = []
1814
  for i in range(1, self.nout_ + 1):
1815
- df = pd.read_csv(
1816
- str(self.equation_file_) + f".out{i}" + ".bkup",
1817
- sep="|",
1818
- )
1819
  # Rename Complexity column to complexity:
1820
  df.rename(
1821
  columns={
@@ -1828,7 +1982,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1828
 
1829
  all_outputs.append(df)
1830
  else:
1831
- all_outputs = [pd.read_csv(str(self.equation_file_) + ".bkup", sep="|")]
 
 
 
1832
  all_outputs[-1].rename(
1833
  columns={
1834
  "Complexity": "complexity",
@@ -1886,7 +2043,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1886
 
1887
  ret_outputs = []
1888
 
1889
- for output in self.equation_file_contents_:
 
 
1890
 
1891
  scores = []
1892
  lastMSE = None
@@ -2035,3 +2194,16 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
2035
  clf, threshold=-np.inf, max_features=select_k_features, prefit=True
2036
  )
2037
  return selector.get_support(indices=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
  import os
3
  import sys
4
  import numpy as np
 
9
  import tempfile
10
  import shutil
11
  from pathlib import Path
12
+ import pickle as pkl
13
  from datetime import datetime
14
  import warnings
15
  from multiprocessing import cpu_count
 
564
  equation_file_contents_ : list[pandas.DataFrame]
565
  Contents of the equation file output by the Julia backend.
566
 
567
+ show_pickle_warnings_ : bool
568
+ Whether to show warnings about what attributes can be pickled.
569
+
570
  Notes
571
  -----
572
  Most default parameters have been tuned over several example equations,
 
810
  f"{k} is not a valid keyword argument for PySRRegressor."
811
  )
812
 
813
+ @classmethod
814
+ def from_file(
815
+ cls,
816
+ equation_file,
817
+ *,
818
+ binary_operators=None,
819
+ unary_operators=None,
820
+ n_features_in=None,
821
+ feature_names_in=None,
822
+ selection_mask=None,
823
+ nout=1,
824
+ **pysr_kwargs,
825
+ ):
826
+ """
827
+ Create a model from a saved model checkpoint or equation file.
828
+
829
+ Parameters
830
+ ----------
831
+ equation_file : str
832
+ Path to a pickle file containing a saved model, or a csv file
833
+ containing equations.
834
+
835
+ binary_operators : list[str]
836
+ The same binary operators used when creating the model.
837
+ Not needed if loading from a pickle file.
838
+
839
+ unary_operators : list[str]
840
+ The same unary operators used when creating the model.
841
+ Not needed if loading from a pickle file.
842
+
843
+ n_features_in : int
844
+ Number of features passed to the model.
845
+ Not needed if loading from a pickle file.
846
+
847
+ feature_names_in : list[str]
848
+ Names of the features passed to the model.
849
+ Not needed if loading from a pickle file.
850
+
851
+ selection_mask : list[bool]
852
+ If using select_k_features, you must pass `model.selection_mask_` here.
853
+ Not needed if loading from a pickle file.
854
+
855
+ nout : int, default=1
856
+ Number of outputs of the model.
857
+ Not needed if loading from a pickle file.
858
+
859
+ pysr_kwargs : dict
860
+ Any other keyword arguments to initialize the PySRRegressor object.
861
+ These will overwrite those stored in the pickle file.
862
+ Not needed if loading from a pickle file.
863
+
864
+ Returns
865
+ -------
866
+ model : PySRRegressor
867
+ The model with fitted equations.
868
+ """
869
+ if os.path.splitext(equation_file)[1] != ".pkl":
870
+ pkl_filename = _csv_filename_to_pkl_filename(equation_file)
871
+ else:
872
+ pkl_filename = equation_file
873
+
874
+ # Try to load model from <equation_file>.pkl
875
+ print(f"Checking if {pkl_filename} exists...")
876
+ if os.path.exists(pkl_filename):
877
+ print(f"Loading model from {pkl_filename}")
878
+ assert binary_operators is None
879
+ assert unary_operators is None
880
+ assert n_features_in is None
881
+ with open(pkl_filename, "rb") as f:
882
+ model = pkl.load(f)
883
+ # Update any parameters if necessary, such as
884
+ # extra_sympy_mappings:
885
+ model.set_params(**pysr_kwargs)
886
+ if "equations_" not in model.__dict__ or model.equations_ is None:
887
+ model.refresh()
888
+
889
+ return model
890
+
891
+ # Else, we re-create it.
892
+ print(
893
+ f"{equation_file} does not exist, "
894
+ "so we must create the model from scratch."
895
+ )
896
+ assert binary_operators is not None
897
+ assert unary_operators is not None
898
+ assert n_features_in is not None
899
+
900
+ # TODO: copy .bkup file if exists.
901
+ model = cls(
902
+ equation_file=equation_file,
903
+ binary_operators=binary_operators,
904
+ unary_operators=unary_operators,
905
+ **pysr_kwargs,
906
+ )
907
+
908
+ model.nout_ = nout
909
+ model.n_features_in_ = n_features_in
910
+
911
+ if feature_names_in is None:
912
+ model.feature_names_in_ = [f"x{i}" for i in range(n_features_in)]
913
+ else:
914
+ assert len(feature_names_in) == n_features_in
915
+ model.feature_names_in_ = feature_names_in
916
+
917
+ if selection_mask is None:
918
+ model.selection_mask_ = np.ones(n_features_in, dtype=bool)
919
+ else:
920
+ model.selection_mask_ = selection_mask
921
+
922
+ model.refresh(checkpoint_file=equation_file)
923
+
924
+ return model
925
+
926
  def __repr__(self):
927
  """
928
  Prints all current equations fitted by the model.
 
991
  from the pickled instance.
992
  """
993
  state = self.__dict__
994
+ show_pickle_warning = not (
995
+ "show_pickle_warnings_" in state and not state["show_pickle_warnings_"]
996
+ )
997
+ if "raw_julia_state_" in state and show_pickle_warning:
998
  warnings.warn(
999
  "raw_julia_state_ cannot be pickled and will be removed from the "
1000
  "serialized instance. This will prevent a `warm_start` fit of any "
1001
  "model that is deserialized via `pickle.load()`."
1002
  )
1003
+ state_keys_containing_lambdas = ["extra_sympy_mappings", "extra_torch_mappings"]
1004
+ for state_key in state_keys_containing_lambdas:
1005
+ if state[state_key] is not None and show_pickle_warning:
1006
+ warnings.warn(
1007
+ f"`{state_key}` cannot be pickled and will be removed from the "
1008
+ "serialized instance. When loading the model, please redefine "
1009
+ f"`{state_key}` at runtime."
1010
+ )
1011
+ state_keys_to_clear = ["raw_julia_state_"] + state_keys_containing_lambdas
1012
  pickled_state = {
1013
+ key: (None if key in state_keys_to_clear else value)
1014
  for key, value in state.items()
1015
  }
1016
+ if ("equations_" in pickled_state) and (
1017
+ pickled_state["equations_"] is not None
1018
+ ):
1019
  pickled_state["output_torch_format"] = False
1020
  pickled_state["output_jax_format"] = False
1021
  if self.nout_ == 1:
 
1038
  ]
1039
  return pickled_state
1040
 
1041
+ def _checkpoint(self):
1042
+ """Saves the model's current state to a checkpoint file.
1043
+
1044
+ This should only be used internally by PySRRegressor."""
1045
+ # Save model state:
1046
+ self.show_pickle_warnings_ = False
1047
+ with open(_csv_filename_to_pkl_filename(self.equation_file_), "wb") as f:
1048
+ pkl.dump(self, f)
1049
+ self.show_pickle_warnings_ = True
1050
+
1051
  @property
1052
  def equations(self): # pragma: no cover
1053
  warnings.warn(
 
1748
  y,
1749
  )
1750
 
1751
+ # Initially, just save model parameters, so that
1752
+ # it can be loaded from an early exit:
1753
+ if not self.temp_equation_file:
1754
+ self._checkpoint()
1755
+
1756
+ # Perform the search:
1757
+ self._run(X, y, mutated_params, weights=weights, seed=seed)
1758
+
1759
+ # Then, after fit, we save again, so the pickle file contains
1760
+ # the equations:
1761
+ if not self.temp_equation_file:
1762
+ self._checkpoint()
1763
+
1764
+ return self
1765
 
1766
  def refresh(self, checkpoint_file=None):
1767
  """
 
1773
  checkpoint_file : str, default=None
1774
  Path to checkpoint hall of fame file to be loaded.
1775
  """
 
1776
  if checkpoint_file:
1777
  self.equation_file_ = checkpoint_file
1778
  self.equation_file_contents_ = None
1779
+ check_is_fitted(self, attributes=["equation_file_"])
1780
  self.equations_ = self.get_hof()
1781
 
1782
  def predict(self, X, index=None):
 
1966
  if self.nout_ > 1:
1967
  all_outputs = []
1968
  for i in range(1, self.nout_ + 1):
1969
+ cur_filename = str(self.equation_file_) + f".out{i}" + ".bkup"
1970
+ if not os.path.exists(cur_filename):
1971
+ cur_filename = str(self.equation_file_) + f".out{i}"
1972
+ df = pd.read_csv(cur_filename, sep="|")
1973
  # Rename Complexity column to complexity:
1974
  df.rename(
1975
  columns={
 
1982
 
1983
  all_outputs.append(df)
1984
  else:
1985
+ filename = str(self.equation_file_) + ".bkup"
1986
+ if not os.path.exists(filename):
1987
+ filename = str(self.equation_file_)
1988
+ all_outputs = [pd.read_csv(filename, sep="|")]
1989
  all_outputs[-1].rename(
1990
  columns={
1991
  "Complexity": "complexity",
 
2043
 
2044
  ret_outputs = []
2045
 
2046
+ equation_file_contents = copy.deepcopy(self.equation_file_contents_)
2047
+
2048
+ for output in equation_file_contents:
2049
 
2050
  scores = []
2051
  lastMSE = None
 
2194
  clf, threshold=-np.inf, max_features=select_k_features, prefit=True
2195
  )
2196
  return selector.get_support(indices=True)
2197
+
2198
+
2199
+ def _csv_filename_to_pkl_filename(csv_filename) -> str:
2200
+ # Assume that the csv filename is of the form "foo.csv"
2201
+ assert str(csv_filename).endswith(".csv")
2202
+
2203
+ dirname = str(os.path.dirname(csv_filename))
2204
+ basename = str(os.path.basename(csv_filename))
2205
+ base = str(os.path.splitext(basename)[0])
2206
+
2207
+ pkl_basename = base + ".pkl"
2208
+
2209
+ return os.path.join(dirname, pkl_basename)
test/test.py CHANGED
@@ -5,13 +5,18 @@ import unittest
5
  import numpy as np
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor
8
- from pysr.sr import run_feature_selection, _handle_feature_selection
 
 
 
 
9
  from sklearn.utils.estimator_checks import check_estimator
10
  import sympy
11
  import pandas as pd
12
  import warnings
13
  import pickle as pkl
14
  import tempfile
 
15
 
16
  DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
17
  DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
@@ -135,7 +140,7 @@ class TestPipeline(unittest.TestCase):
135
  # These tests are flaky, so don't fail test:
136
  try:
137
  np.testing.assert_almost_equal(
138
- model.predict(X.copy())[:, 0], X[:, 0] ** 2, decimal=4
139
  )
140
  except AssertionError:
141
  print("Error in test_multioutput_weighted_with_callable_temp_equation")
@@ -144,7 +149,7 @@ class TestPipeline(unittest.TestCase):
144
 
145
  try:
146
  np.testing.assert_almost_equal(
147
- model.predict(X.copy())[:, 1], X[:, 1] ** 2, decimal=4
148
  )
149
  except AssertionError:
150
  print("Error in test_multioutput_weighted_with_callable_temp_equation")
@@ -280,6 +285,72 @@ class TestPipeline(unittest.TestCase):
280
  model.fit(X.values, y.values, Xresampled=Xresampled.values)
281
  self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  class TestBest(unittest.TestCase):
285
  def setUp(self):
@@ -330,7 +401,7 @@ class TestBest(unittest.TestCase):
330
  X = self.X
331
  y = self.y
332
  for f in [self.model.predict, self.equations_.iloc[-1]["lambda_format"]]:
333
- np.testing.assert_almost_equal(f(X), y, decimal=4)
334
 
335
 
336
  class TestFeatureSelection(unittest.TestCase):
@@ -364,6 +435,20 @@ class TestFeatureSelection(unittest.TestCase):
364
  class TestMiscellaneous(unittest.TestCase):
365
  """Test miscellaneous functions."""
366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  def test_deprecation(self):
368
  """Ensure that deprecation works as expected.
369
 
 
5
  import numpy as np
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor
8
+ from pysr.sr import (
9
+ run_feature_selection,
10
+ _handle_feature_selection,
11
+ _csv_filename_to_pkl_filename,
12
+ )
13
  from sklearn.utils.estimator_checks import check_estimator
14
  import sympy
15
  import pandas as pd
16
  import warnings
17
  import pickle as pkl
18
  import tempfile
19
+ from pathlib import Path
20
 
21
  DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
22
  DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
 
140
  # These tests are flaky, so don't fail test:
141
  try:
142
  np.testing.assert_almost_equal(
143
+ model.predict(X.copy())[:, 0], X[:, 0] ** 2, decimal=3
144
  )
145
  except AssertionError:
146
  print("Error in test_multioutput_weighted_with_callable_temp_equation")
 
149
 
150
  try:
151
  np.testing.assert_almost_equal(
152
+ model.predict(X.copy())[:, 1], X[:, 1] ** 2, decimal=3
153
  )
154
  except AssertionError:
155
  print("Error in test_multioutput_weighted_with_callable_temp_equation")
 
285
  model.fit(X.values, y.values, Xresampled=Xresampled.values)
286
  self.assertLess(np.average((model.predict(X.values) - y.values) ** 2), 1e-4)
287
 
288
+ def test_load_model(self):
289
+ """See if we can load a ran model from the equation file."""
290
+ csv_file_data = """
291
+ Complexity|MSE|Equation
292
+ 1|0.19951081|1.9762075
293
+ 3|0.12717344|(f0 + 1.4724599)
294
+ 4|0.104823045|pow_abs(2.2683423, cos(f3))"""
295
+ # Strip the indents:
296
+ csv_file_data = "\n".join([l.strip() for l in csv_file_data.split("\n")])
297
+
298
+ for from_backup in [False, True]:
299
+ rand_dir = Path(tempfile.mkdtemp())
300
+ equation_filename = str(rand_dir / "equation.csv")
301
+ with open(equation_filename + (".bkup" if from_backup else ""), "w") as f:
302
+ f.write(csv_file_data)
303
+ model = PySRRegressor.from_file(
304
+ equation_filename,
305
+ n_features_in=5,
306
+ feature_names_in=["f0", "f1", "f2", "f3", "f4"],
307
+ binary_operators=["+", "*", "/", "-", "^"],
308
+ unary_operators=["cos"],
309
+ )
310
+ X = self.rstate.rand(100, 5)
311
+ y_truth = 2.2683423 ** np.cos(X[:, 3])
312
+ y_test = model.predict(X, 2)
313
+
314
+ np.testing.assert_allclose(y_truth, y_test)
315
+
316
+ def test_load_model_simple(self):
317
+ # Test that we can simply load a model from its equation file.
318
+ y = self.X[:, [0, 1]] ** 2
319
+ model = PySRRegressor(
320
+ # Test that passing a single operator works:
321
+ unary_operators="sq(x) = x^2",
322
+ binary_operators="plus",
323
+ extra_sympy_mappings={"sq": lambda x: x**2},
324
+ **self.default_test_kwargs,
325
+ procs=0,
326
+ denoise=True,
327
+ early_stop_condition="stop_if(loss, complexity) = loss < 0.05 && complexity == 2",
328
+ )
329
+ rand_dir = Path(tempfile.mkdtemp())
330
+ equation_file = rand_dir / "equations.csv"
331
+ model.set_params(temp_equation_file=False)
332
+ model.set_params(equation_file=equation_file)
333
+ model.fit(self.X, y)
334
+
335
+ # lambda functions are removed from the pickling, so we need
336
+ # to pass it during the loading:
337
+ model2 = PySRRegressor.from_file(
338
+ model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
339
+ )
340
+
341
+ np.testing.assert_allclose(model.predict(self.X), model2.predict(self.X))
342
+
343
+ # Try again, but using only the pickle file:
344
+ for file_to_delete in [str(equation_file), str(equation_file) + ".bkup"]:
345
+ if os.path.exists(file_to_delete):
346
+ os.remove(file_to_delete)
347
+
348
+ pickle_file = rand_dir / "equations.pkl"
349
+ model3 = PySRRegressor.from_file(
350
+ model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
351
+ )
352
+ np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
353
+
354
 
355
  class TestBest(unittest.TestCase):
356
  def setUp(self):
 
401
  X = self.X
402
  y = self.y
403
  for f in [self.model.predict, self.equations_.iloc[-1]["lambda_format"]]:
404
+ np.testing.assert_almost_equal(f(X), y, decimal=3)
405
 
406
 
407
  class TestFeatureSelection(unittest.TestCase):
 
435
  class TestMiscellaneous(unittest.TestCase):
436
  """Test miscellaneous functions."""
437
 
438
+ def test_csv_to_pkl_conversion(self):
439
+ """Test that csv filename to pkl filename works as expected."""
440
+ tmpdir = Path(tempfile.mkdtemp())
441
+ equation_file = tmpdir / "equations.389479384.28378374.csv"
442
+ expected_pkl_file = tmpdir / "equations.389479384.28378374.pkl"
443
+
444
+ # First, test inputting the paths:
445
+ test_pkl_file = _csv_filename_to_pkl_filename(equation_file)
446
+ self.assertEqual(test_pkl_file, str(expected_pkl_file))
447
+
448
+ # Next, test inputting the strings.
449
+ test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
450
+ self.assertEqual(test_pkl_file, str(expected_pkl_file))
451
+
452
  def test_deprecation(self):
453
  """Ensure that deprecation works as expected.
454
 
test/test_jax.py CHANGED
@@ -76,7 +76,7 @@ class TestJAX(unittest.TestCase):
76
  np.testing.assert_almost_equal(
77
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
78
  np.square(np.cos(X[:, 1])), # Select feature 1
79
- decimal=4,
80
  )
81
 
82
  def test_feature_selection_custom_operators(self):
 
76
  np.testing.assert_almost_equal(
77
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
78
  np.square(np.cos(X[:, 1])), # Select feature 1
79
+ decimal=3,
80
  )
81
 
82
  def test_feature_selection_custom_operators(self):