Spaces:
Running
Running
File size: 2,622 Bytes
b271a60 ad8ed14 ef66f4a 139b8d0 976f8d8 139b8d0 3876e75 618a3f8 24a4349 618a3f8 92eb30b 618a3f8 3555cfd 3876e75 01ec39f 3876e75 01ec39f f145620 3876e75 2a98f83 d8d6e2b 3876e75 139b8d0 3555cfd 24a4349 3555cfd 2a98f83 b271a60 2a98f83 98fa83e 24a4349 98fa83e ef66f4a 2a98f83 8685680 ef66f4a 8685680 ef66f4a 8685680 ef66f4a 8685680 ef66f4a 24a4349 ef66f4a 92eb30b ef66f4a 2a98f83 ef66f4a b271a60 ef66f4a ad8ed14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import fnmatch
import sys
import unittest
import warnings
import click
from ..test import (
get_runtests_cli,
runtests,
runtests_dev,
runtests_jax,
runtests_startup,
runtests_torch,
)
@click.group("pysr")
@click.pass_context
def pysr(context):
    # Root ``pysr`` command group; subcommands attach via ``@pysr.command``.
    # Notes are kept as comments rather than a docstring, because click would
    # turn a docstring into the group's ``--help`` description.
    # ``context`` is the click.Context injected by ``pass_context``; it is
    # not used by the group itself.
    pass
@pysr.command("install", help="DEPRECATED (dependencies are now installed at import).")
@click.option(
    "-p",
    "julia_project",
    "--project",
    default=None,
    type=str,
)
@click.option("-q", "--quiet", is_flag=True, default=False, help="Disable logging.")
@click.option(
    "--precompile",
    "precompile",
    flag_value=True,
    default=None,
)
@click.option(
    "--no-precompile",
    "precompile",
    flag_value=False,
    default=None,
)
def _install(julia_project, quiet, precompile):
    """Deprecated ``pysr install`` command that only emits a warning.

    The options (``--project``, ``--quiet``, ``--precompile`` /
    ``--no-precompile``) are still accepted so existing invocations keep
    parsing, but their values are ignored.
    """
    # FutureWarning is the category Python reserves for deprecations aimed at
    # end users of an application: like UserWarning it is shown by default,
    # whereas a DeprecationWarning raised from this (non-__main__) module
    # would be filtered out and never reach the user.
    warnings.warn(
        "This command is deprecated. Julia dependencies are now installed at first import.",
        FutureWarning,
    )
# Names of the test suites that the ``pysr test`` command (``_tests``) knows
# how to load.
TEST_OPTIONS = {
    "main",
    "jax",
    "torch",
    "cli",
    "dev",
    "startup",
}
def _split_test_names(tests):
    """Split a comma-separated list of test-suite names.

    Whitespace around each name is stripped, empty entries (e.g. from a
    trailing comma) are dropped, and duplicates are removed while keeping the
    order of first appearance: ``"main, jax,main"`` -> ``["main", "jax"]``.
    """
    stripped = (name.strip() for name in tests.split(","))
    return list(dict.fromkeys(name for name in stripped if name))


@pysr.command("test")
@click.argument("tests", nargs=1)
@click.option(
    "-k",
    "expressions",
    multiple=True,
    type=str,
    help="Filter expressions to select specific tests.",
)
def _tests(tests, expressions):
    """Run parts of the PySR test suite.

    Choose from main, jax, torch, cli, dev, and startup. You can give multiple tests, separated by commas.
    """
    # NOTE: the docstring above is the command's --help text; keep it stable.
    # Each suite name maps to a zero-argument loader returning its TestCase
    # classes. ``get_runtests_cli`` is only invoked when "cli" is requested.
    suite_loaders = {
        "main": lambda: runtests(just_tests=True),
        "jax": lambda: runtests_jax(just_tests=True),
        "torch": lambda: runtests_torch(just_tests=True),
        "cli": lambda: get_runtests_cli()(just_tests=True),
        "dev": lambda: runtests_dev(just_tests=True),
        "startup": lambda: runtests_startup(just_tests=True),
    }

    test_cases = []
    for name in _split_test_names(tests):
        load_suite = suite_loaders.get(name)
        if load_suite is None:
            warnings.warn(f"Invalid test {name}. Skipping.")
            continue
        test_cases.extend(load_suite())

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_case in test_cases:
        for test in loader.loadTestsFromTestCase(test_case):
            # With no -k filters every test is kept; otherwise a test is kept
            # when any expression matches somewhere in its id (fnmatch glob
            # wrapped in "*...*").
            if not expressions or any(
                fnmatch.fnmatch(test.id(), f"*{expression}*")
                for expression in expressions
            ):
                suite.addTest(test)

    results = unittest.TextTestRunner().run(suite)
    # Propagate failure to the shell so CI can detect it.
    if not results.wasSuccessful():
        sys.exit(1)
|