MilesCranmer commited on
Commit
b8a97f1
1 Parent(s): b53e7fa

Use .pkl instead of .csv.pkl

Browse files
Files changed (2) hide show
  1. pysr/sr.py +28 -10
  2. test/test.py +20 -1
pysr/sr.py CHANGED
@@ -930,7 +930,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
930
  This should only be used internally by PySRRegressor."""
931
  # Save model state:
932
  self.show_pickle_warnings_ = False
933
- with open(str(self.equation_file_) + ".pkl", "wb") as f:
934
  pkl.dump(self, f)
935
  self.show_pickle_warnings_ = True
936
 
@@ -1636,14 +1636,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1636
 
1637
  # Initially, just save model parameters, so that
1638
  # it can be loaded from an early exit:
1639
- self._checkpoint()
 
1640
 
1641
  # Perform the search:
1642
  self._run(X, y, mutated_params, weights=weights, seed=seed)
1643
 
1644
  # Then, after fit, we save again, so the pickle file contains
1645
  # the equations:
1646
- self._checkpoint()
 
1647
 
1648
  return self
1649
 
@@ -2077,6 +2079,17 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
2077
  return selector.get_support(indices=True)
2078
 
2079
 
 
 
 
 
 
 
 
 
 
 
 
2080
  def load(
2081
  equation_file,
2082
  *,
@@ -2094,7 +2107,8 @@ def load(
2094
  Parameters
2095
  ----------
2096
  equation_file : str
2097
- Path to a csv file containing equations.
 
2098
 
2099
  binary_operators : list[str], default=["+", "-", "*", "/"]
2100
  The same binary operators used when creating the model.
@@ -2123,14 +2137,19 @@ def load(
2123
  model : PySRRegressor
2124
  The model with fitted equations.
2125
  """
 
 
 
 
 
2126
  # Try to load model from <equation_file>.pkl
2127
- print(f"Checking if {equation_file}.pkl exists...")
2128
- if os.path.exists(str(equation_file) + ".pkl"):
2129
- print(f"Loading model from {equation_file}.pkl.")
2130
  assert binary_operators is None
2131
  assert unary_operators is None
2132
  assert n_features_in is None
2133
- with open(str(equation_file) + ".pkl", "rb") as f:
2134
  model = pkl.load(f)
2135
  # Update any parameters if necessary, such as
2136
  # extra_sympy_mappings:
@@ -2142,8 +2161,7 @@ def load(
2142
 
2143
  # Else, we re-create it.
2144
  print(
2145
- f"{equation_file}.pkl does not exist, "
2146
- "so we must create the model from scratch."
2147
  )
2148
  assert binary_operators is not None
2149
  assert unary_operators is not None
 
930
  This should only be used internally by PySRRegressor."""
931
  # Save model state:
932
  self.show_pickle_warnings_ = False
933
+ with open(_csv_filename_to_pkl_filename(self.equation_file_), "wb") as f:
934
  pkl.dump(self, f)
935
  self.show_pickle_warnings_ = True
936
 
 
1636
 
1637
  # Initially, just save model parameters, so that
1638
  # it can be loaded from an early exit:
1639
+ if not self.temp_equation_file:
1640
+ self._checkpoint()
1641
 
1642
  # Perform the search:
1643
  self._run(X, y, mutated_params, weights=weights, seed=seed)
1644
 
1645
  # Then, after fit, we save again, so the pickle file contains
1646
  # the equations:
1647
+ if not self.temp_equation_file:
1648
+ self._checkpoint()
1649
 
1650
  return self
1651
 
 
2079
  return selector.get_support(indices=True)
2080
 
2081
 
2082
+ def _csv_filename_to_pkl_filename(csv_filename) -> str:
2083
+ # Assume that the csv filename is of the form "foo.csv"
2084
+ dirname = str(os.path.dirname(csv_filename))
2085
+ basename = str(os.path.basename(csv_filename))
2086
+ base = str(os.path.splitext(basename)[0])
2087
+
2088
+ pkl_basename = base + ".pkl"
2089
+
2090
+ return os.path.join(dirname, pkl_basename)
2091
+
2092
+
2093
  def load(
2094
  equation_file,
2095
  *,
 
2107
  Parameters
2108
  ----------
2109
  equation_file : str
2110
+ Path to a csv file containing equations, or a pickle file
2111
+ containing the model.
2112
 
2113
  binary_operators : list[str], default=["+", "-", "*", "/"]
2114
  The same binary operators used when creating the model.
 
2137
  model : PySRRegressor
2138
  The model with fitted equations.
2139
  """
2140
+ if os.path.splitext(equation_file)[1] != ".pkl":
2141
+ pkl_filename = _csv_filename_to_pkl_filename(equation_file)
2142
+ else:
2143
+ pkl_filename = equation_file
2144
+
2145
  # Try to load model from <equation_file>.pkl
2146
+ print(f"Checking if {pkl_filename} exists...")
2147
+ if os.path.exists(pkl_filename):
2148
+ print(f"Loading model from {pkl_filename}")
2149
  assert binary_operators is None
2150
  assert unary_operators is None
2151
  assert n_features_in is None
2152
+ with open(pkl_filename, "rb") as f:
2153
  model = pkl.load(f)
2154
  # Update any parameters if necessary, such as
2155
  # extra_sympy_mappings:
 
2161
 
2162
  # Else, we re-create it.
2163
  print(
2164
+ f"{equation_file} does not exist, " "so we must create the model from scratch."
 
2165
  )
2166
  assert binary_operators is not None
2167
  assert unary_operators is not None
test/test.py CHANGED
@@ -5,7 +5,11 @@ import unittest
5
  import numpy as np
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor, load
8
- from pysr.sr import run_feature_selection, _handle_feature_selection
 
 
 
 
9
  from sklearn.utils.estimator_checks import check_estimator
10
  import sympy
11
  import pandas as pd
@@ -341,6 +345,7 @@ class TestPipeline(unittest.TestCase):
341
  if os.path.exists(file_to_delete):
342
  os.remove(file_to_delete)
343
 
 
344
  model3 = load(
345
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
346
  )
@@ -430,6 +435,20 @@ class TestFeatureSelection(unittest.TestCase):
430
  class TestMiscellaneous(unittest.TestCase):
431
  """Test miscellaneous functions."""
432
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  def test_deprecation(self):
434
  """Ensure that deprecation works as expected.
435
 
 
5
  import numpy as np
6
  from sklearn import model_selection
7
  from pysr import PySRRegressor, load
8
+ from pysr.sr import (
9
+ run_feature_selection,
10
+ _handle_feature_selection,
11
+ _csv_filename_to_pkl_filename,
12
+ )
13
  from sklearn.utils.estimator_checks import check_estimator
14
  import sympy
15
  import pandas as pd
 
345
  if os.path.exists(file_to_delete):
346
  os.remove(file_to_delete)
347
 
348
+ pickle_file = rand_dir / "equations.pkl"
349
  model3 = load(
350
  model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
351
  )
 
435
  class TestMiscellaneous(unittest.TestCase):
436
  """Test miscellaneous functions."""
437
 
438
+ def test_csv_to_pkl_conversion(self):
439
+ """Test that csv filename to pkl filename works as expected."""
440
+ tmpdir = Path(tempfile.mkdtemp())
441
+ equation_file = tmpdir / "equations.389479384.28378374.csv"
442
+ expected_pkl_file = tmpdir / "equations.389479384.28378374.pkl"
443
+
444
+ # First, test inputting the paths:
445
+ test_pkl_file = _csv_filename_to_pkl_filename(equation_file)
446
+ self.assertEqual(test_pkl_file, str(expected_pkl_file))
447
+
448
+ # Next, test inputting the strings.
449
+ test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
450
+ self.assertEqual(test_pkl_file, str(expected_pkl_file))
451
+
452
  def test_deprecation(self):
453
  """Ensure that deprecation works as expected.
454