File size: 1,677 Bytes
139b8d0
976f8d8
139b8d0
3876e75
618a3f8
 
 
 
 
 
 
3555cfd
3876e75
01ec39f
3876e75
01ec39f
f145620
3876e75
 
2a98f83
d8d6e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3876e75
139b8d0
 
 
3555cfd
 
9b3be67
3555cfd
 
2a98f83
 
3555cfd
2a98f83
98fa83e
9b3be67
98fa83e
2a98f83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import warnings

import click

from ..test import (
    get_runtests_cli,
    runtests,
    runtests_jax,
    runtests_torch,
    runtests_warm_start,
)


@click.group("pysr")
@click.pass_context
def pysr(context):
    # Root command group of the CLI; the `install` and `test` subcommands in
    # this module register themselves on it via `@pysr.command(...)`.
    #
    # NOTE: deliberately documented with comments rather than a docstring:
    # click uses a group's docstring as its `--help` text, so adding one
    # would change the CLI's visible output.
    #
    # `context` is injected by `@click.pass_context`; nothing here needs it,
    # so the previous unused `ctx = context` alias has been dropped.
    pass


@pysr.command("install", help="DEPRECATED (dependencies are now installed at import).")
@click.option("-p", "--project", "julia_project", type=str, default=None)
@click.option("-q", "--quiet", is_flag=True, default=False, help="Disable logging.")
@click.option("--precompile", "precompile", flag_value=True, default=None)
@click.option("--no-precompile", "precompile", flag_value=False, default=None)
def _install(julia_project, quiet, precompile):
    """Deprecated no-op: accept the legacy options and emit a warning.

    All three option values are received from click but ignored; the command
    performs no installation work itself.
    """
    # The explicit `help=` on the command decorator is what the CLI shows;
    # this docstring is for maintainers only.
    deprecation_notice = "This command is deprecated. Julia dependencies are now installed at first import."
    warnings.warn(deprecation_notice)


# Names accepted by the `test` subcommand.
TEST_OPTIONS = {"main", "jax", "torch", "cli", "warm-start"}


@pysr.command("test")
@click.argument("tests", nargs=1)
def _tests(tests):
    """Run parts of the PySR test suite.

    Choose from main, jax, torch, cli, and warm-start. You can give multiple tests, separated by commas.
    """
    # NOTE: the docstring above doubles as the command's `--help` text (click
    # reads it), so it is intentionally left unchanged.

    # Map each test name to a zero-argument callable. The table is built per
    # call so names are resolved when the command runs, and the "cli" entry is
    # wrapped in a lambda so `get_runtests_cli()` is only invoked on request.
    runners = {
        "main": runtests,
        "jax": runtests_jax,
        "torch": runtests_torch,
        "cli": lambda: get_runtests_cli()(),
        "warm-start": runtests_warm_start,
    }

    for raw_name in tests.split(","):
        # Tolerate whitespace around names, e.g. `pysr test "main, jax"`.
        test = raw_name.strip()
        runner = runners.get(test) if test in TEST_OPTIONS else None
        if runner is None:
            # Unknown (or empty) names are reported and skipped, not fatal.
            warnings.warn(f"Invalid test {test}. Skipping.")
            continue
        runner()