"""Adapted from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py
"""
import sys
from typing import Any, Dict, Optional
import pytest
import torch
from packaging.version import Version
from pkg_resources import get_distribution
from pytest import MarkDecorator
from tests.helpers.package_available import (
_COMET_AVAILABLE,
_DEEPSPEED_AVAILABLE,
_FAIRSCALE_AVAILABLE,
_IS_WINDOWS,
_MLFLOW_AVAILABLE,
_NEPTUNE_AVAILABLE,
_SH_AVAILABLE,
_TPU_AVAILABLE,
_WANDB_AVAILABLE,
)
class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:

    ```python
    @RunIf(min_torch="1.8")
    @pytest.mark.parametrize("arg1", [1.0, 2.0])
    def test_wrapper(arg1):
        assert arg1 > 0
    ```
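
    Requirements can be combined; the test runs only when all of them are met,
    e.g. (illustrative test name):

    ```python
    @RunIf(min_gpus=2, deepspeed=True)
    def test_deepspeed_multi_gpu():
        ...
    ```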
"""
def __new__(
cls,
min_gpus: int = 0,
min_torch: Optional[str] = None,
max_torch: Optional[str] = None,
min_python: Optional[str] = None,
skip_windows: bool = False,
sh: bool = False,
tpu: bool = False,
fairscale: bool = False,
deepspeed: bool = False,
wandb: bool = False,
neptune: bool = False,
comet: bool = False,
mlflow: bool = False,
**kwargs: Dict[Any, Any],
) -> MarkDecorator:
"""Creates a new `@RunIf` `MarkDecorator` decorator.
:param min_gpus: Min number of GPUs required to run test.
:param min_torch: Minimum pytorch version to run test.
:param max_torch: Maximum pytorch version to run test.
:param min_python: Minimum python version required to run test.
:param skip_windows: Skip test for Windows platform.
:param tpu: If TPU is available.
:param sh: If `sh` module is required to run the test.
:param fairscale: If `fairscale` module is required to run the test.
:param deepspeed: If `deepspeed` module is required to run the test.
:param wandb: If `wandb` module is required to run the test.
:param neptune: If `neptune` module is required to run the test.
:param comet: If `comet` module is required to run the test.
:param mlflow: If `mlflow` module is required to run the test.
:param kwargs: Native `pytest.mark.skipif` keyword arguments.
"""
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")

        if min_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) < Version(min_torch))
            reasons.append(f"torch>={min_torch}")

        if max_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) >= Version(max_torch))
            reasons.append(f"torch<{max_torch}")

        if min_python:
            py_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            )
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")

        if skip_windows:
            conditions.append(_IS_WINDOWS)
            reasons.append("does not run on Windows")

        if tpu:
            conditions.append(not _TPU_AVAILABLE)
            reasons.append("TPU")

        if sh:
            conditions.append(not _SH_AVAILABLE)
            reasons.append("sh")

        if fairscale:
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("fairscale")

        if deepspeed:
            conditions.append(not _DEEPSPEED_AVAILABLE)
            reasons.append("deepspeed")

        if wandb:
            conditions.append(not _WANDB_AVAILABLE)
            reasons.append("wandb")

        if neptune:
            conditions.append(not _NEPTUNE_AVAILABLE)
            reasons.append("neptune")

        if comet:
            conditions.append(not _COMET_AVAILABLE)
            reasons.append("comet")

        if mlflow:
            conditions.append(not _MLFLOW_AVAILABLE)
            reasons.append("mlflow")
        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(reasons)}]",
            **kwargs,
        )