"""Adapted from: | |
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py | |
""" | |
import sys | |
from typing import Any, Dict, Optional | |
import pytest | |
import torch | |
from packaging.version import Version | |
from pkg_resources import get_distribution | |
from pytest import MarkDecorator | |
from tests.helpers.package_available import ( | |
_COMET_AVAILABLE, | |
_DEEPSPEED_AVAILABLE, | |
_FAIRSCALE_AVAILABLE, | |
_IS_WINDOWS, | |
_MLFLOW_AVAILABLE, | |
_NEPTUNE_AVAILABLE, | |
_SH_AVAILABLE, | |
_TPU_AVAILABLE, | |
_WANDB_AVAILABLE, | |
) | |

class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:
        ```python
        @RunIf(min_torch="1.8")
        @pytest.mark.parametrize("arg1", [1.0, 2.0])
        def test_wrapper(arg1):
            assert arg1 > 0
        ```
    """
    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        sh: bool = False,
        tpu: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        wandb: bool = False,
        neptune: bool = False,
        comet: bool = False,
        mlflow: bool = False,
        **kwargs: Dict[Any, Any],
    ) -> MarkDecorator:
        """Creates a new `@RunIf` `MarkDecorator` decorator.

        :param min_gpus: Min number of GPUs required to run the test.
        :param min_torch: Minimum pytorch version required to run the test.
        :param max_torch: Maximum pytorch version allowed to run the test.
        :param min_python: Minimum python version required to run the test.
        :param skip_windows: Skip the test on the Windows platform.
        :param sh: If the `sh` module is required to run the test.
        :param tpu: If a TPU is required to run the test.
        :param fairscale: If the `fairscale` module is required to run the test.
        :param deepspeed: If the `deepspeed` module is required to run the test.
        :param wandb: If the `wandb` module is required to run the test.
        :param neptune: If the `neptune` module is required to run the test.
        :param comet: If the `comet` module is required to run the test.
        :param mlflow: If the `mlflow` module is required to run the test.
        :param kwargs: Native `pytest.mark.skipif` keyword arguments.
        """
        # collect skip conditions together with a human-readable reason for each
        conditions = []
        reasons = []

        if min_gpus:
            conditions.append(torch.cuda.device_count() < min_gpus)
            reasons.append(f"GPUs>={min_gpus}")
        if min_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) < Version(min_torch))
            reasons.append(f"torch>={min_torch}")
        if max_torch:
            torch_version = get_distribution("torch").version
            conditions.append(Version(torch_version) >= Version(max_torch))
            reasons.append(f"torch<{max_torch}")
        if min_python:
            py_version = (
                f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
            )
            conditions.append(Version(py_version) < Version(min_python))
            reasons.append(f"python>={min_python}")
        if skip_windows:
            conditions.append(_IS_WINDOWS)
            reasons.append("does not run on Windows")
        if tpu:
            conditions.append(not _TPU_AVAILABLE)
            reasons.append("TPU")
        if sh:
            conditions.append(not _SH_AVAILABLE)
            reasons.append("sh")
        if fairscale:
            conditions.append(not _FAIRSCALE_AVAILABLE)
            reasons.append("fairscale")
        if deepspeed:
            conditions.append(not _DEEPSPEED_AVAILABLE)
            reasons.append("deepspeed")
        if wandb:
            conditions.append(not _WANDB_AVAILABLE)
            reasons.append("wandb")
        if neptune:
            conditions.append(not _NEPTUNE_AVAILABLE)
            reasons.append("neptune")
        if comet:
            conditions.append(not _COMET_AVAILABLE)
            reasons.append("comet")
        if mlflow:
            conditions.append(not _MLFLOW_AVAILABLE)
            reasons.append("mlflow")

        # keep only the reasons whose skip condition actually triggered,
        # so the skip message lists just the unmet requirements
        reasons = [rs for cond, rs in zip(conditions, reasons) if cond]

        return pytest.mark.skipif(
            condition=any(conditions),
            reason=f"Requires: [{' + '.join(reasons)}]",
            **kwargs,
        )
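

# A minimal usage sketch (hypothetical test, assuming this module is importable
# as `tests.helpers.runif`): several requirements can be combined in a single
# decorator, and any extra keyword arguments are forwarded unchanged to
# `pytest.mark.skipif`.
#
#     from tests.helpers.runif import RunIf
#
#     @RunIf(min_gpus=1, min_torch="1.10", skip_windows=True)
#     @pytest.mark.parametrize("batch_size", [1, 2])
#     def test_gpu_training(batch_size):
#         assert torch.cuda.device_count() >= 1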