File size: 1,977 Bytes
139b8d0
976f8d8
139b8d0
3876e75
31ecc71
3555cfd
3876e75
01ec39f
3876e75
01ec39f
f145620
3876e75
 
d8d6e2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3876e75
139b8d0
 
 
3555cfd
 
31ecc71
3555cfd
 
 
98fa83e
3555cfd
98fa83e
 
31ecc71
98fa83e
3555cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import warnings

import click

from ..test import get_runtests_cli, runtests, runtests_jax, runtests_torch


@click.group("pysr")
@click.pass_context
def pysr(context):
    # Root command group for the ``pysr`` CLI; the subcommands below attach
    # to it via ``@pysr.command(...)``. The group body itself does no work.
    # ``context`` is injected by ``click.pass_context`` and is currently
    # unused (the previous ``ctx = context`` assignment was dead code).
    pass


@pysr.command("install", help="Install Julia dependencies for PySR.")
@click.option(
    "-p",
    "julia_project",
    "--project",
    type=str,
    default=None,
    metavar="PROJECT_DIRECTORY",
    help="Install in a specific Julia project (e.g., a local copy of SymbolicRegression.jl).",
)
@click.option(
    "-q",
    "--quiet",
    is_flag=True,
    default=False,
    help="Disable logging.",
)
@click.option(
    "--precompile",
    "precompile",
    default=None,
    flag_value=True,
    help="Force precompilation of Julia libraries.",
)
@click.option(
    "--no-precompile",
    "precompile",
    default=None,
    flag_value=False,
    help="Disable precompilation.",
)
def _install(julia_project, quiet, precompile):
    """Deprecated ``pysr install`` command.

    The options are still declared so they appear in the signature, but
    none of them are read: the command only emits a deprecation warning.
    """
    deprecation_notice = (
        "This command is deprecated. "
        "Julia dependencies are now installed at first import."
    )
    warnings.warn(deprecation_notice)


# Names accepted by ``pysr test``. Stored as a set for membership checks;
# note that iterating a set of strings yields an order that can differ
# between interpreter runs (string hash randomization), so sort it before
# showing it to users.
TEST_OPTIONS = {
    "main",
    "jax",
    "torch",
    "cli",
}


@pysr.command("test", help="Run PySR test suite.")
@click.argument("tests", nargs=-1)
def _tests(tests):
    """Run part of the PySR test suite.

    Choose from main, jax, torch, and cli.

    Each requested name is run in the order given. Names not in
    ``TEST_OPTIONS`` trigger a warning and are skipped.

    Raises:
        click.UsageError: if no test names are given.
    """
    if not tests:
        # Sort the options: iterating a set of strings gives an order that
        # changes between runs, which made this message non-deterministic.
        available = ", ".join(sorted(TEST_OPTIONS))
        raise click.UsageError(
            "At least one test must be specified. "
            f"The following are available: {available}."
        )

    # Map each suite name to a zero-argument runner. The CLI runner is only
    # constructed (via get_runtests_cli) when that suite is actually requested.
    runners = {
        "main": runtests,
        "jax": runtests_jax,
        "torch": runtests_torch,
        "cli": lambda: get_runtests_cli()(),
    }

    for test in tests:
        if test not in TEST_OPTIONS:
            warnings.warn(f"Invalid test {test}. Skipping.")
            continue
        runner = runners.get(test)
        # Guard against a name added to TEST_OPTIONS without a runner: the
        # original if/elif chain silently did nothing in that case.
        if runner is not None:
            runner()