File size: 2,622 Bytes
b271a60
ad8ed14
ef66f4a
139b8d0
976f8d8
139b8d0
3876e75
618a3f8
 
 
24a4349
618a3f8
92eb30b
618a3f8
 
3555cfd
3876e75
01ec39f
3876e75
01ec39f
f145620
3876e75
 
2a98f83
d8d6e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3876e75
139b8d0
 
 
3555cfd
 
24a4349
3555cfd
 
2a98f83
 
b271a60
 
 
 
 
 
 
 
2a98f83
98fa83e
24a4349
98fa83e
ef66f4a
2a98f83
8685680
ef66f4a
8685680
ef66f4a
8685680
ef66f4a
8685680
 
ef66f4a
24a4349
ef66f4a
92eb30b
ef66f4a
2a98f83
 
ef66f4a
 
 
 
b271a60
 
 
 
 
 
 
 
ef66f4a
ad8ed14
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import fnmatch
import sys
import unittest
import warnings

import click

from ..test import (
    get_runtests_cli,
    runtests,
    runtests_dev,
    runtests_jax,
    runtests_startup,
    runtests_torch,
)


@click.group("pysr")
@click.pass_context
def pysr(context):
    # Root command group of the `pysr` CLI; subcommands attach to it via
    # `@pysr.command(...)`. The click context is accepted (because of
    # `@click.pass_context`) but not needed here. Documentation lives in
    # comments rather than a docstring because click would otherwise render
    # a docstring as this group's `--help` description.
    pass


@pysr.command(
    "install",
    help="DEPRECATED (dependencies are now installed at import).",
)
@click.option("-p", "--project", "julia_project", type=str, default=None)
@click.option("-q", "--quiet", is_flag=True, default=False, help="Disable logging.")
@click.option("--precompile", "precompile", default=None, flag_value=True)
@click.option("--no-precompile", "precompile", default=None, flag_value=False)
def _install(julia_project, quiet, precompile):
    """Deprecated ``pysr install`` command, kept so old invocations still parse.

    All options are accepted and ignored; the only effect is a warning issued
    through ``warnings.warn`` (default category, i.e. ``UserWarning``).
    """
    deprecation_message = "This command is deprecated. Julia dependencies are now installed at first import."
    warnings.warn(deprecation_message)


# Test-group names that `pysr test` knows about (the same names its help text
# lists). NOTE(review): not referenced elsewhere in the visible part of this
# module — confirm whether other code reads it before changing it.
TEST_OPTIONS = {
    "main",
    "jax",
    "torch",
    "cli",
    "dev",
    "startup",
}


@pysr.command("test")
@click.argument("tests", nargs=1)
@click.option(
    "-k",
    "expressions",
    multiple=True,
    type=str,
    help="Filter expressions to select specific tests.",
)
def _tests(tests, expressions):
    """Run parts of the PySR test suite.

    Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
    """
    # The docstring above is rendered by click as `pysr test --help`, so it is
    # kept verbatim; implementation notes live in these comments instead.
    #
    # Map each test-group name to a zero-argument callable that returns the
    # TestCase classes for that group. The lambdas defer the work — notably
    # `get_runtests_cli()` — until a group is actually requested.
    collectors = {
        "main": lambda: runtests(just_tests=True),
        "jax": lambda: runtests_jax(just_tests=True),
        "torch": lambda: runtests_torch(just_tests=True),
        "cli": lambda: get_runtests_cli()(just_tests=True),
        "dev": lambda: runtests_dev(just_tests=True),
        "startup": lambda: runtests_startup(just_tests=True),
    }

    test_cases = []
    for raw_name in tests.split(","):
        # Tolerate "main, jax" (spaces after commas) and "main," (trailing or
        # doubled commas); previously these produced spurious "Invalid test"
        # warnings and silently dropped valid groups.
        name = raw_name.strip()
        if not name:
            continue
        collect = collectors.get(name)
        if collect is None:
            warnings.warn(f"Invalid test {name}. Skipping.")
            continue
        test_cases.extend(collect())

    # `-k` expressions are matched as glob-style substrings of each test id.
    # NOTE(review): fnmatch.fnmatch normalizes case on case-insensitive
    # platforms (e.g. Windows), so filtering is OS-dependent there.
    patterns = [f"*{expression}*" for expression in expressions]

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_case in test_cases:
        for single_test in loader.loadTestsFromTestCase(test_case):
            if not patterns or any(
                fnmatch.fnmatch(single_test.id(), pattern) for pattern in patterns
            ):
                suite.addTest(single_test)

    results = unittest.TextTestRunner().run(suite)

    # Propagate failures to the shell so callers (e.g. CI) see a non-zero exit.
    if not results.wasSuccessful():
        sys.exit(1)