File size: 4,641 Bytes
fa7be76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Adapted from:



https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py

"""

import sys
from typing import Any, Dict, Optional

import pytest
import torch
from packaging.version import Version
from pkg_resources import get_distribution
from pytest import MarkDecorator

from tests.helpers.package_available import (
    _COMET_AVAILABLE,
    _DEEPSPEED_AVAILABLE,
    _FAIRSCALE_AVAILABLE,
    _IS_WINDOWS,
    _MLFLOW_AVAILABLE,
    _NEPTUNE_AVAILABLE,
    _SH_AVAILABLE,
    _TPU_AVAILABLE,
    _WANDB_AVAILABLE,
)


class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:

    ```python
        @RunIf(min_torch="1.8")
        @pytest.mark.parametrize("arg1", [1.0, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0
    ```
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        sh: bool = False,
        tpu: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        wandb: bool = False,
        neptune: bool = False,
        comet: bool = False,
        mlflow: bool = False,
        **kwargs: Any,
    ) -> MarkDecorator:
        """Creates a new `@RunIf` `MarkDecorator` decorator.

        Every requirement that is requested contributes one skip condition. The test is
        skipped if any requested requirement is unmet, and the skip reason lists only the
        unmet requirements, e.g. ``Requires: [GPUs>=2 + wandb]``.

        :param min_gpus: Min number of GPUs required to run test.
        :param min_torch: Minimum pytorch version to run test (inclusive).
        :param max_torch: Maximum pytorch version to run test (exclusive).
        :param min_python: Minimum python version required to run test.
        :param skip_windows: Skip test for Windows platform.
        :param sh: If `sh` module is required to run the test.
        :param tpu: If a TPU is required to run the test.
        :param fairscale: If `fairscale` module is required to run the test.
        :param deepspeed: If `deepspeed` module is required to run the test.
        :param wandb: If `wandb` module is required to run the test.
        :param neptune: If `neptune` module is required to run the test.
        :param comet: If `comet` module is required to run the test.
        :param mlflow: If `mlflow` module is required to run the test.
        :param kwargs: Native `pytest.mark.skipif` keyword arguments. Must not contain
            `condition` or `reason`; those are set by this decorator.
        :return: A `pytest.mark.skipif` marker.
        """
        # Parallel lists: conditions[i] is True when the requirement labelled reasons[i]
        # is NOT met (i.e. the test should be skipped because of it).
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")

        if min_torch or max_torch:
            # Look up the installed torch version once, shared by both bounds.
            torch_version = Version(get_distribution("torch").version)
            if min_torch:
                conditions.append(torch_version < Version(min_torch))
                reasons.append(f"torch>={min_torch}")
            if max_torch:
                conditions.append(torch_version >= Version(max_torch))
                reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            )
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")

        if skip_windows:
            conditions.append(_IS_WINDOWS)
            reasons.append("does not run on Windows")

        # Optional hardware/package requirements: (requested?, available?, label).
        # Order is kept stable so the generated skip reason is deterministic.
        optional_requirements = (
            (tpu, _TPU_AVAILABLE, "TPU"),
            (sh, _SH_AVAILABLE, "sh"),
            (fairscale, _FAIRSCALE_AVAILABLE, "fairscale"),
            (deepspeed, _DEEPSPEED_AVAILABLE, "deepspeed"),
            (wandb, _WANDB_AVAILABLE, "wandb"),
            (neptune, _NEPTUNE_AVAILABLE, "neptune"),
            (comet, _COMET_AVAILABLE, "comet"),
            (mlflow, _MLFLOW_AVAILABLE, "mlflow"),
        )
        for requested, available, label in optional_requirements:
            if requested:
                conditions.append(not available)
                reasons.append(label)

        # Only mention the requirements that actually caused the skip.
        unmet = [label for cond, label in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(unmet)}]",
            **kwargs,
        )