Spaces:
Running
Running
MilesCranmer
commited on
Add warm_start test
Browse files- pysr/_cli/main.py +11 -3
- pysr/test/__init__.py +2 -0
- pysr/test/params.py +8 -0
- pysr/test/test.py +7 -7
- pysr/test/test_warm_start.py +104 -0
pysr/_cli/main.py
CHANGED
@@ -2,7 +2,13 @@ import warnings
|
|
2 |
|
3 |
import click
|
4 |
|
5 |
-
from ..test import
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
|
8 |
@click.group("pysr")
|
@@ -42,7 +48,7 @@ def _install(julia_project, quiet, precompile):
|
|
42 |
)
|
43 |
|
44 |
|
45 |
-
TEST_OPTIONS = {"main", "jax", "torch", "cli"}
|
46 |
|
47 |
|
48 |
@pysr.command("test", help="Run PySR test suite.")
|
@@ -50,7 +56,7 @@ TEST_OPTIONS = {"main", "jax", "torch", "cli"}
|
|
50 |
def _tests(tests):
|
51 |
"""Run part of the PySR test suite.
|
52 |
|
53 |
-
Choose from main, jax, torch, and
|
54 |
"""
|
55 |
if len(tests) == 0:
|
56 |
raise click.UsageError(
|
@@ -71,5 +77,7 @@ def _tests(tests):
|
|
71 |
elif test == "cli":
|
72 |
runtests_cli = get_runtests_cli()
|
73 |
runtests_cli()
|
|
|
|
|
74 |
else:
|
75 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
|
|
2 |
|
3 |
import click
|
4 |
|
5 |
+
from ..test import (
|
6 |
+
get_runtests_cli,
|
7 |
+
runtests,
|
8 |
+
runtests_jax,
|
9 |
+
runtests_torch,
|
10 |
+
runtests_warm_start,
|
11 |
+
)
|
12 |
|
13 |
|
14 |
@click.group("pysr")
|
|
|
48 |
)
|
49 |
|
50 |
|
51 |
+
TEST_OPTIONS = {"main", "jax", "torch", "cli", "warm_start"}
|
52 |
|
53 |
|
54 |
@pysr.command("test", help="Run PySR test suite.")
|
|
|
56 |
def _tests(tests):
|
57 |
"""Run part of the PySR test suite.
|
58 |
|
59 |
+
Choose from main, jax, torch, cli, and warm_start.
|
60 |
"""
|
61 |
if len(tests) == 0:
|
62 |
raise click.UsageError(
|
|
|
77 |
elif test == "cli":
|
78 |
runtests_cli = get_runtests_cli()
|
79 |
runtests_cli()
|
80 |
+
elif test == "warm_start":
|
81 |
+
runtests_warm_start()
|
82 |
else:
|
83 |
warnings.warn(f"Invalid test {test}. Skipping.")
|
pysr/test/__init__.py
CHANGED
@@ -2,10 +2,12 @@ from .test import runtests
|
|
2 |
from .test_cli import get_runtests as get_runtests_cli
|
3 |
from .test_jax import runtests as runtests_jax
|
4 |
from .test_torch import runtests as runtests_torch
|
|
|
5 |
|
6 |
__all__ = [
|
7 |
"runtests",
|
8 |
"runtests_jax",
|
9 |
"runtests_torch",
|
10 |
"get_runtests_cli",
|
|
|
11 |
]
|
|
|
2 |
from .test_cli import get_runtests as get_runtests_cli
|
3 |
from .test_jax import runtests as runtests_jax
|
4 |
from .test_torch import runtests as runtests_torch
|
5 |
+
from .test_warm_start import runtests as runtests_warm_start
|
6 |
|
7 |
__all__ = [
|
8 |
"runtests",
|
9 |
"runtests_jax",
|
10 |
"runtests_torch",
|
11 |
"get_runtests_cli",
|
12 |
+
"runtests_warm_start",
|
13 |
]
|
pysr/test/params.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import inspect
|
2 |
+
|
3 |
+
from .. import PySRRegressor
|
4 |
+
|
5 |
+
DEFAULT_PARAMS = inspect.signature(PySRRegressor.__init__).parameters
|
6 |
+
DEFAULT_NITERATIONS = DEFAULT_PARAMS["niterations"].default
|
7 |
+
DEFAULT_POPULATIONS = DEFAULT_PARAMS["populations"].default
|
8 |
+
DEFAULT_NCYCLES = DEFAULT_PARAMS["ncyclesperiteration"].default
|
pysr/test/test.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import inspect
|
2 |
import os
|
3 |
import pickle as pkl
|
4 |
import tempfile
|
@@ -12,16 +11,17 @@ import pandas as pd
|
|
12 |
import sympy
|
13 |
from sklearn.utils.estimator_checks import check_estimator
|
14 |
|
15 |
-
from .. import PySRRegressor
|
16 |
from ..export_latex import sympy2latex
|
17 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
18 |
from ..sr import _check_assertions, _process_constraints, idx_model_selection
|
19 |
from ..utils import _csv_filename_to_pkl_filename
|
20 |
-
|
21 |
-
|
22 |
-
DEFAULT_NITERATIONS
|
23 |
-
|
24 |
-
|
|
|
25 |
|
26 |
|
27 |
class TestPipeline(unittest.TestCase):
|
|
|
|
|
1 |
import os
|
2 |
import pickle as pkl
|
3 |
import tempfile
|
|
|
11 |
import sympy
|
12 |
from sklearn.utils.estimator_checks import check_estimator
|
13 |
|
14 |
+
from .. import PySRRegressor
|
15 |
from ..export_latex import sympy2latex
|
16 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
17 |
from ..sr import _check_assertions, _process_constraints, idx_model_selection
|
18 |
from ..utils import _csv_filename_to_pkl_filename
|
19 |
+
from .params import (
|
20 |
+
DEFAULT_NCYCLES,
|
21 |
+
DEFAULT_NITERATIONS,
|
22 |
+
DEFAULT_PARAMS,
|
23 |
+
DEFAULT_POPULATIONS,
|
24 |
+
)
|
25 |
|
26 |
|
27 |
class TestPipeline(unittest.TestCase):
|
pysr/test/test_warm_start.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import tempfile
|
3 |
+
import textwrap
|
4 |
+
import unittest
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from .. import PySRRegressor
|
10 |
+
from .params import (
|
11 |
+
DEFAULT_NCYCLES,
|
12 |
+
DEFAULT_NITERATIONS,
|
13 |
+
DEFAULT_PARAMS,
|
14 |
+
DEFAULT_POPULATIONS,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class TestWarmStart(unittest.TestCase):
|
19 |
+
def setUp(self):
|
20 |
+
# Using inspect,
|
21 |
+
# get default niterations from PySRRegressor, and double them:
|
22 |
+
self.default_test_kwargs = dict(
|
23 |
+
progress=False,
|
24 |
+
model_selection="accuracy",
|
25 |
+
niterations=DEFAULT_NITERATIONS * 2,
|
26 |
+
populations=DEFAULT_POPULATIONS * 2,
|
27 |
+
temp_equation_file=True,
|
28 |
+
)
|
29 |
+
self.rstate = np.random.RandomState(0)
|
30 |
+
self.X = self.rstate.randn(100, 5)
|
31 |
+
|
32 |
+
def test_warm_start_from_file(self):
|
33 |
+
"""Test that we can warm start in another process."""
|
34 |
+
with tempfile.TemporaryDirectory() as tmpdirname:
|
35 |
+
model = PySRRegressor(
|
36 |
+
**self.default_test_kwargs,
|
37 |
+
unary_operators=["cos"],
|
38 |
+
)
|
39 |
+
model.warm_start = True
|
40 |
+
model.temp_equation_file = False
|
41 |
+
model.equation_file = Path(tmpdirname) / "equations.csv"
|
42 |
+
model.deterministic = True
|
43 |
+
model.multithreading = False
|
44 |
+
model.random_state = 0
|
45 |
+
model.procs = 0
|
46 |
+
model.early_stop_condition = 1e-10
|
47 |
+
|
48 |
+
rstate = np.random.RandomState(0)
|
49 |
+
X = rstate.randn(100, 2)
|
50 |
+
y = np.cos(X[:, 0]) ** 2
|
51 |
+
model.fit(X, y)
|
52 |
+
|
53 |
+
best_loss = model.equations_.iloc[-1]["loss"]
|
54 |
+
|
55 |
+
# Save X and y to a file:
|
56 |
+
X_file = Path(tmpdirname) / "X.npy"
|
57 |
+
y_file = Path(tmpdirname) / "y.npy"
|
58 |
+
np.save(X_file, X)
|
59 |
+
np.save(y_file, y)
|
60 |
+
# Now, create a new process and warm start from the file:
|
61 |
+
result = subprocess.run(
|
62 |
+
[
|
63 |
+
"python",
|
64 |
+
"-c",
|
65 |
+
textwrap.dedent(
|
66 |
+
f"""
|
67 |
+
from pysr import PySRRegressor
|
68 |
+
import numpy as np
|
69 |
+
|
70 |
+
X = np.load("{X_file}")
|
71 |
+
y = np.load("{y_file}")
|
72 |
+
|
73 |
+
print("Loading model from file")
|
74 |
+
model = PySRRegressor.from_file("{model.equation_file}")
|
75 |
+
|
76 |
+
assert model.julia_state_ is not None
|
77 |
+
|
78 |
+
model.warm_start = True
|
79 |
+
model.niterations = 0
|
80 |
+
model.max_evals = 0
|
81 |
+
model.ncyclesperiteration = 0
|
82 |
+
|
83 |
+
model.fit(X, y)
|
84 |
+
|
85 |
+
best_loss = model.equations_.iloc[-1]["loss"]
|
86 |
+
|
87 |
+
assert best_loss <= {best_loss}
|
88 |
+
"""
|
89 |
+
),
|
90 |
+
],
|
91 |
+
stdout=subprocess.PIPE,
|
92 |
+
stderr=subprocess.PIPE,
|
93 |
+
)
|
94 |
+
self.assertEqual(result.returncode, 0)
|
95 |
+
self.assertIn("Loading model from file", result.stdout.decode())
|
96 |
+
self.assertIn("Started!", result.stderr.decode())
|
97 |
+
|
98 |
+
|
99 |
+
def runtests():
|
100 |
+
suite = unittest.TestSuite()
|
101 |
+
loader = unittest.TestLoader()
|
102 |
+
suite.addTests(loader.loadTestsFromTestCase(TestWarmStart))
|
103 |
+
runner = unittest.TextTestRunner()
|
104 |
+
return runner.run(suite)
|