Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
b8a97f1
1
Parent(s):
b53e7fa
Use .pkl instead of .csv.pkl
Browse files- pysr/sr.py +28 -10
- test/test.py +20 -1
pysr/sr.py
CHANGED
@@ -930,7 +930,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
930 |
This should only be used internally by PySRRegressor."""
|
931 |
# Save model state:
|
932 |
self.show_pickle_warnings_ = False
|
933 |
-
with open(
|
934 |
pkl.dump(self, f)
|
935 |
self.show_pickle_warnings_ = True
|
936 |
|
@@ -1636,14 +1636,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1636 |
|
1637 |
# Initially, just save model parameters, so that
|
1638 |
# it can be loaded from an early exit:
|
1639 |
-
self.
|
|
|
1640 |
|
1641 |
# Perform the search:
|
1642 |
self._run(X, y, mutated_params, weights=weights, seed=seed)
|
1643 |
|
1644 |
# Then, after fit, we save again, so the pickle file contains
|
1645 |
# the equations:
|
1646 |
-
self.
|
|
|
1647 |
|
1648 |
return self
|
1649 |
|
@@ -2077,6 +2079,17 @@ def run_feature_selection(X, y, select_k_features, random_state=None):
|
|
2077 |
return selector.get_support(indices=True)
|
2078 |
|
2079 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2080 |
def load(
|
2081 |
equation_file,
|
2082 |
*,
|
@@ -2094,7 +2107,8 @@ def load(
|
|
2094 |
Parameters
|
2095 |
----------
|
2096 |
equation_file : str
|
2097 |
-
Path to a csv file containing equations
|
|
|
2098 |
|
2099 |
binary_operators : list[str], default=["+", "-", "*", "/"]
|
2100 |
The same binary operators used when creating the model.
|
@@ -2123,14 +2137,19 @@ def load(
|
|
2123 |
model : PySRRegressor
|
2124 |
The model with fitted equations.
|
2125 |
"""
|
|
|
|
|
|
|
|
|
|
|
2126 |
# Try to load model from <equation_file>.pkl
|
2127 |
-
print(f"Checking if {
|
2128 |
-
if os.path.exists(
|
2129 |
-
print(f"Loading model from {
|
2130 |
assert binary_operators is None
|
2131 |
assert unary_operators is None
|
2132 |
assert n_features_in is None
|
2133 |
-
with open(
|
2134 |
model = pkl.load(f)
|
2135 |
# Update any parameters if necessary, such as
|
2136 |
# extra_sympy_mappings:
|
@@ -2142,8 +2161,7 @@ def load(
|
|
2142 |
|
2143 |
# Else, we re-create it.
|
2144 |
print(
|
2145 |
-
f"{equation_file}
|
2146 |
-
"so we must create the model from scratch."
|
2147 |
)
|
2148 |
assert binary_operators is not None
|
2149 |
assert unary_operators is not None
|
|
|
930 |
This should only be used internally by PySRRegressor."""
|
931 |
# Save model state:
|
932 |
self.show_pickle_warnings_ = False
|
933 |
+
with open(_csv_filename_to_pkl_filename(self.equation_file_), "wb") as f:
|
934 |
pkl.dump(self, f)
|
935 |
self.show_pickle_warnings_ = True
|
936 |
|
|
|
1636 |
|
1637 |
# Initially, just save model parameters, so that
|
1638 |
# it can be loaded from an early exit:
|
1639 |
+
if not self.temp_equation_file:
|
1640 |
+
self._checkpoint()
|
1641 |
|
1642 |
# Perform the search:
|
1643 |
self._run(X, y, mutated_params, weights=weights, seed=seed)
|
1644 |
|
1645 |
# Then, after fit, we save again, so the pickle file contains
|
1646 |
# the equations:
|
1647 |
+
if not self.temp_equation_file:
|
1648 |
+
self._checkpoint()
|
1649 |
|
1650 |
return self
|
1651 |
|
|
|
2079 |
return selector.get_support(indices=True)
|
2080 |
|
2081 |
|
2082 |
+
def _csv_filename_to_pkl_filename(csv_filename) -> str:
|
2083 |
+
# Assume that the csv filename is of the form "foo.csv"
|
2084 |
+
dirname = str(os.path.dirname(csv_filename))
|
2085 |
+
basename = str(os.path.basename(csv_filename))
|
2086 |
+
base = str(os.path.splitext(basename)[0])
|
2087 |
+
|
2088 |
+
pkl_basename = base + ".pkl"
|
2089 |
+
|
2090 |
+
return os.path.join(dirname, pkl_basename)
|
2091 |
+
|
2092 |
+
|
2093 |
def load(
|
2094 |
equation_file,
|
2095 |
*,
|
|
|
2107 |
Parameters
|
2108 |
----------
|
2109 |
equation_file : str
|
2110 |
+
Path to a csv file containing equations, or a pickle file
|
2111 |
+
containing the model.
|
2112 |
|
2113 |
binary_operators : list[str], default=["+", "-", "*", "/"]
|
2114 |
The same binary operators used when creating the model.
|
|
|
2137 |
model : PySRRegressor
|
2138 |
The model with fitted equations.
|
2139 |
"""
|
2140 |
+
if os.path.splitext(equation_file)[1] != ".pkl":
|
2141 |
+
pkl_filename = _csv_filename_to_pkl_filename(equation_file)
|
2142 |
+
else:
|
2143 |
+
pkl_filename = equation_file
|
2144 |
+
|
2145 |
# Try to load model from <equation_file>.pkl
|
2146 |
+
print(f"Checking if {pkl_filename} exists...")
|
2147 |
+
if os.path.exists(pkl_filename):
|
2148 |
+
print(f"Loading model from {pkl_filename}")
|
2149 |
assert binary_operators is None
|
2150 |
assert unary_operators is None
|
2151 |
assert n_features_in is None
|
2152 |
+
with open(pkl_filename, "rb") as f:
|
2153 |
model = pkl.load(f)
|
2154 |
# Update any parameters if necessary, such as
|
2155 |
# extra_sympy_mappings:
|
|
|
2161 |
|
2162 |
# Else, we re-create it.
|
2163 |
print(
|
2164 |
+
f"{equation_file} does not exist, " "so we must create the model from scratch."
|
|
|
2165 |
)
|
2166 |
assert binary_operators is not None
|
2167 |
assert unary_operators is not None
|
test/test.py
CHANGED
@@ -5,7 +5,11 @@ import unittest
|
|
5 |
import numpy as np
|
6 |
from sklearn import model_selection
|
7 |
from pysr import PySRRegressor, load
|
8 |
-
from pysr.sr import
|
|
|
|
|
|
|
|
|
9 |
from sklearn.utils.estimator_checks import check_estimator
|
10 |
import sympy
|
11 |
import pandas as pd
|
@@ -341,6 +345,7 @@ class TestPipeline(unittest.TestCase):
|
|
341 |
if os.path.exists(file_to_delete):
|
342 |
os.remove(file_to_delete)
|
343 |
|
|
|
344 |
model3 = load(
|
345 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
346 |
)
|
@@ -430,6 +435,20 @@ class TestFeatureSelection(unittest.TestCase):
|
|
430 |
class TestMiscellaneous(unittest.TestCase):
|
431 |
"""Test miscellaneous functions."""
|
432 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
433 |
def test_deprecation(self):
|
434 |
"""Ensure that deprecation works as expected.
|
435 |
|
|
|
5 |
import numpy as np
|
6 |
from sklearn import model_selection
|
7 |
from pysr import PySRRegressor, load
|
8 |
+
from pysr.sr import (
|
9 |
+
run_feature_selection,
|
10 |
+
_handle_feature_selection,
|
11 |
+
_csv_filename_to_pkl_filename,
|
12 |
+
)
|
13 |
from sklearn.utils.estimator_checks import check_estimator
|
14 |
import sympy
|
15 |
import pandas as pd
|
|
|
345 |
if os.path.exists(file_to_delete):
|
346 |
os.remove(file_to_delete)
|
347 |
|
348 |
+
pickle_file = rand_dir / "equations.pkl"
|
349 |
model3 = load(
|
350 |
model.equation_file_, extra_sympy_mappings={"sq": lambda x: x**2}
|
351 |
)
|
|
|
435 |
class TestMiscellaneous(unittest.TestCase):
|
436 |
"""Test miscellaneous functions."""
|
437 |
|
438 |
+
def test_csv_to_pkl_conversion(self):
|
439 |
+
"""Test that csv filename to pkl filename works as expected."""
|
440 |
+
tmpdir = Path(tempfile.mkdtemp())
|
441 |
+
equation_file = tmpdir / "equations.389479384.28378374.csv"
|
442 |
+
expected_pkl_file = tmpdir / "equations.389479384.28378374.pkl"
|
443 |
+
|
444 |
+
# First, test inputting the paths:
|
445 |
+
test_pkl_file = _csv_filename_to_pkl_filename(equation_file)
|
446 |
+
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
447 |
+
|
448 |
+
# Next, test inputting the strings.
|
449 |
+
test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
|
450 |
+
self.assertEqual(test_pkl_file, str(expected_pkl_file))
|
451 |
+
|
452 |
def test_deprecation(self):
|
453 |
"""Ensure that deprecation works as expected.
|
454 |
|