"""Adapted from: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py """ import sys from typing import Any, Dict, Optional import pytest import torch from packaging.version import Version from pkg_resources import get_distribution from pytest import MarkDecorator from tests.helpers.package_available import ( _COMET_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, _IS_WINDOWS, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _SH_AVAILABLE, _TPU_AVAILABLE, _WANDB_AVAILABLE, ) class RunIf: """RunIf wrapper for conditional skipping of tests. Fully compatible with `@pytest.mark`. Example: ```python @RunIf(min_torch="1.8") @pytest.mark.parametrize("arg1", [1.0, 2.0]) def test_wrapper(arg1): assert arg1 > 0 ``` """ def __new__( cls, min_gpus: int = 0, min_torch: Optional[str] = None, max_torch: Optional[str] = None, min_python: Optional[str] = None, skip_windows: bool = False, sh: bool = False, tpu: bool = False, fairscale: bool = False, deepspeed: bool = False, wandb: bool = False, neptune: bool = False, comet: bool = False, mlflow: bool = False, **kwargs: Dict[Any, Any], ) -> MarkDecorator: """Creates a new `@RunIf` `MarkDecorator` decorator. :param min_gpus: Min number of GPUs required to run test. :param min_torch: Minimum pytorch version to run test. :param max_torch: Maximum pytorch version to run test. :param min_python: Minimum python version required to run test. :param skip_windows: Skip test for Windows platform. :param tpu: If TPU is available. :param sh: If `sh` module is required to run the test. :param fairscale: If `fairscale` module is required to run the test. :param deepspeed: If `deepspeed` module is required to run the test. :param wandb: If `wandb` module is required to run the test. :param neptune: If `neptune` module is required to run the test. :param comet: If `comet` module is required to run the test. :param mlflow: If `mlflow` module is required to run the test. :param kwargs: Native `pytest.mark.skipif` keyword arguments. """ conditions = [] reasons = [] if min_gpus: conditions.append(torch.cuda.device_count() < min_gpus) reasons.append(f"GPUs>={min_gpus}") if min_torch: torch_version = get_distribution("torch").version conditions.append(Version(torch_version) < Version(min_torch)) reasons.append(f"torch>={min_torch}") if max_torch: torch_version = get_distribution("torch").version conditions.append(Version(torch_version) >= Version(max_torch)) reasons.append(f"torch<{max_torch}") if min_python: py_version = ( f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" ) conditions.append(Version(py_version) < Version(min_python)) reasons.append(f"python>={min_python}") if skip_windows: conditions.append(_IS_WINDOWS) reasons.append("does not run on Windows") if tpu: conditions.append(not _TPU_AVAILABLE) reasons.append("TPU") if sh: conditions.append(not _SH_AVAILABLE) reasons.append("sh") if fairscale: conditions.append(not _FAIRSCALE_AVAILABLE) reasons.append("fairscale") if deepspeed: conditions.append(not _DEEPSPEED_AVAILABLE) reasons.append("deepspeed") if wandb: conditions.append(not _WANDB_AVAILABLE) reasons.append("wandb") if neptune: conditions.append(not _NEPTUNE_AVAILABLE) reasons.append("neptune") if comet: conditions.append(not _COMET_AVAILABLE) reasons.append("comet") if mlflow: conditions.append(not _MLFLOW_AVAILABLE) reasons.append("mlflow") reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs, )