import builtins
import platform
import sys
from contextlib import suppress
from functools import wraps
from os import environ
from unittest import SkipTest

import joblib
import numpy as np
import pytest
from _pytest.doctest import DoctestItem
from threadpoolctl import threadpool_limits

from sklearn import set_config
from sklearn._min_dependencies import PYTEST_MIN_VERSION
from sklearn.datasets import (
    fetch_20newsgroups,
    fetch_20newsgroups_vectorized,
    fetch_california_housing,
    fetch_covtype,
    fetch_kddcup99,
    fetch_lfw_pairs,
    fetch_lfw_people,
    fetch_olivetti_faces,
    fetch_rcv1,
    fetch_species_distributions,
)
from sklearn.utils._testing import get_pytest_filterwarning_lines
from sklearn.utils.fixes import (
    _IS_32BIT,
    np_base_version,
    parse_version,
    sp_version,
)

if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
    raise ImportError(
        f"Your version of pytest is too old. Got version {pytest.__version__}, you"
        f" should have pytest >= {PYTEST_MIN_VERSION} installed."
    )
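
# Starting with SciPy 1.10, the sample images live in `scipy.datasets`, which
# fetches the data over the network (via pooch) instead of bundling it.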
scipy_datasets_require_network = sp_version >= parse_version("1.10")


def raccoon_face_or_skip():
    """Return the raccoon face image, skipping when it cannot be fetched."""
    if scipy_datasets_require_network:
        run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
        if not run_network_tests:
            raise SkipTest("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")

        try:
            import pooch
        except ImportError:
            raise SkipTest("test requires pooch to be installed")

        from scipy.datasets import face
    else:
        from scipy.misc import face

    return face(gray=True)
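

# Map each dataset fixture name to its fetcher so that tests requiring a
# download can be identified at collection time (see
# `pytest_collection_modifyitems` below).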
dataset_fetchers = {
    "fetch_20newsgroups_fxt": fetch_20newsgroups,
    "fetch_20newsgroups_vectorized_fxt": fetch_20newsgroups_vectorized,
    "fetch_california_housing_fxt": fetch_california_housing,
    "fetch_covtype_fxt": fetch_covtype,
    "fetch_kddcup99_fxt": fetch_kddcup99,
    "fetch_lfw_pairs_fxt": fetch_lfw_pairs,
    "fetch_lfw_people_fxt": fetch_lfw_people,
    "fetch_olivetti_faces_fxt": fetch_olivetti_faces,
    "fetch_rcv1_fxt": fetch_rcv1,
    "fetch_species_distributions_fxt": fetch_species_distributions,
}

if scipy_datasets_require_network:
    dataset_fetchers["raccoon_face_fxt"] = raccoon_face_or_skip


_SKIP32_MARK = pytest.mark.skipif(
    environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") != "1",
    reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 dtype tests",
)
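

# Parametrizes tests over float64 (always) and float32 (only when
# SKLEARN_RUN_FLOAT32_TESTS=1, see `_SKIP32_MARK` above).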
@pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
def global_dtype(request):
    yield request.param


def _fetch_fixture(f):
    """Fetch dataset (download if missing and requested by environment)."""
    download_if_missing = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"

    @wraps(f)
    def wrapped(*args, **kwargs):
        kwargs["download_if_missing"] = download_if_missing
        try:
            return f(*args, **kwargs)
        except OSError as e:
            if str(e) != "Data not found and `download_if_missing` is False":
                raise
            pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")

    return pytest.fixture(lambda: wrapped)
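

# Each `*_fxt` fixture below yields the wrapped fetcher, which the test then
# calls itself. A minimal usage sketch (hypothetical test, for illustration
# only):
#
#     def test_20newsgroups_shape(fetch_20newsgroups_fxt):
#         data = fetch_20newsgroups_fxt(subset="train")
#         assert len(data.data) == len(data.target)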
fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
fetch_20newsgroups_vectorized_fxt = _fetch_fixture(fetch_20newsgroups_vectorized)
fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
fetch_lfw_pairs_fxt = _fetch_fixture(fetch_lfw_pairs)
fetch_lfw_people_fxt = _fetch_fixture(fetch_lfw_people)
fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)
fetch_species_distributions_fxt = _fetch_fixture(fetch_species_distributions)
raccoon_face_fxt = pytest.fixture(raccoon_face_or_skip)


def pytest_collection_modifyitems(config, items):
    """Called after collect is completed.

    Parameters
    ----------
    config : pytest config
    items : list of collected items
    """
    run_network_tests = environ.get("SKLEARN_SKIP_NETWORK_TESTS", "1") == "0"
    skip_network = pytest.mark.skip(
        reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0"
    )
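
    # Identify all datasets used by the collected tests so that they can be
    # downloaded once, up front, instead of concurrently by several workers.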
    dataset_features_set = set(dataset_fetchers)
    datasets_to_download = set()

    for item in items:
        if isinstance(item, DoctestItem) and "fetch_" in item.name:
            fetcher_function_name = item.name.split(".")[-1]
            dataset_fetchers_key = f"{fetcher_function_name}_fxt"
            dataset_to_fetch = {dataset_fetchers_key} & dataset_features_set
        elif not hasattr(item, "fixturenames"):
            continue
        else:
            item_fixtures = set(item.fixturenames)
            dataset_to_fetch = item_fixtures & dataset_features_set

        if not dataset_to_fetch:
            continue

        if run_network_tests:
            datasets_to_download |= dataset_to_fetch
        else:
            # Network tests are disabled: skip any test that needs a dataset.
            item.add_marker(skip_network)
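
    # Only download datasets on the first worker spawned by pytest-xdist to
    # avoid concurrent, thread-unsafe downloads. Without pytest-xdist, the
    # default worker id "gw0" means the download still happens here.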
    worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
    if worker_id == "gw0" and run_network_tests:
        for name in datasets_to_download:
            with suppress(SkipTest):
                dataset_fetchers[name]()

    for item in items:
        if (
            item.name.endswith("GradientBoostingClassifier")
            and platform.machine() == "aarch64"
        ):
            marker = pytest.mark.xfail(
                reason=(
                    "known failure. See "
                    "https://github.com/scikit-learn/scikit-learn/issues/17797"
                )
            )
            item.add_marker(marker)

    skip_doctests = False
    try:
        import matplotlib
    except ImportError:
        skip_doctests = True
        reason = "matplotlib is required to run the doctests"

    if _IS_32BIT:
        reason = "doctests are only run when the default numpy int is 64 bits."
        skip_doctests = True
    elif sys.platform.startswith("win32"):
        reason = (
            "doctests are not run for Windows because numpy arrays "
            "repr is inconsistent across platforms."
        )
        skip_doctests = True
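
    # The doctests are written for the NEP 51 scalar repr introduced in
    # numpy 2 (e.g. `np.float64(3.0)` rather than `3.0`), so they cannot pass
    # on older numpy versions.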
    if np_base_version < parse_version("2"):
        reason = "Due to NEP 51 numpy scalar repr has changed in numpy 2"
        skip_doctests = True
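
    # Likewise, scipy 1.14 changed the repr of sparse containers and the
    # doctests are written against the newer repr.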
    if sp_version < parse_version("1.14"):
        reason = "Scipy sparse matrix repr has changed in scipy 1.14"
        skip_doctests = True
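
    # A doctest normally executes with the enclosing module's namespace as its
    # globals; reset `globs` so every doctest is self-contained. See
    # https://docs.python.org/3/library/doctest.html#what-s-the-execution-context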
    for item in items:
        if isinstance(item, DoctestItem):
            item.dtest.globs = {}

    if skip_doctests:
        skip_marker = pytest.mark.skip(reason=reason)

        for item in items:
            if isinstance(item, DoctestItem):
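                # Keep the `sklearn._config.config_context` doctest running;
                # every other doctest receives the skip marker.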
                if item.name != "sklearn._config.config_context":
                    item.add_marker(skip_marker)

    try:
        import PIL

        pillow_installed = True
    except ImportError:
        pillow_installed = False

    if not pillow_installed:
        skip_marker = pytest.mark.skip(reason="pillow (or PIL) not installed!")
        for item in items:
            if item.name in [
                "sklearn.feature_extraction.image.PatchExtractor",
                "sklearn.feature_extraction.image.extract_patches_2d",
            ]:
                item.add_marker(skip_marker)


@pytest.fixture(scope="function")
def pyplot():
    """Setup and teardown fixture for matplotlib.

    This fixture checks if we can import matplotlib. If not, the tests will be
    skipped. Otherwise, we close the figures before and after running the
    functions.

    Returns
    -------
    pyplot : module
        The ``matplotlib.pyplot`` module.
    """
    pyplot = pytest.importorskip("matplotlib.pyplot")
    pyplot.close("all")
    yield pyplot
    pyplot.close("all")


def pytest_generate_tests(metafunc):
    """Parametrize the global_random_seed fixture based on the
    SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable.

    The goal of this fixture is to prevent tests that use it from being
    sensitive to a specific seed value while still being deterministic by
    default.

    See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
    variable for instructions on how to use this fixture:

    https://scikit-learn.org/dev/computing/parallelism.html#sklearn-tests-global-random-seed
    """
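    # All seeds in [0, 99] are valid; the environment variable selects which
    # ones actually run. Accepted values (mirroring the parsing below):
    #   SKLEARN_TESTS_GLOBAL_RANDOM_SEED="42"    -> seed 42
    #   SKLEARN_TESTS_GLOBAL_RANDOM_SEED="40-42" -> seeds 40, 41 and 42
    #   SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"   -> all 100 seeds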
    RANDOM_SEED_RANGE = list(range(100))
    random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")

    default_random_seeds = [42]

    if random_seed_var is None:
        random_seeds = default_random_seeds
    elif random_seed_var == "all":
        random_seeds = RANDOM_SEED_RANGE
    else:
        if "-" in random_seed_var:
            start, stop = random_seed_var.split("-")
            random_seeds = list(range(int(start), int(stop) + 1))
        else:
            random_seeds = [int(random_seed_var)]

        if min(random_seeds) < 0 or max(random_seeds) > 99:
            raise ValueError(
                "The value(s) of the environment variable "
                "SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
                f"(or 'all'), got: {random_seed_var}"
            )

    if "global_random_seed" in metafunc.fixturenames:
        metafunc.parametrize("global_random_seed", random_seeds)


def pytest_configure(config):
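    # Use matplotlib's non-interactive Agg backend so that tests never require
    # a display.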
    try:
        import matplotlib

        matplotlib.use("agg")
    except ImportError:
        pass
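
    # Limit BLAS/OpenMP parallelism: when pytest-xdist spawns several workers,
    # split the physical cores between them to avoid oversubscribing the CPU.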
    allowed_parallelism = joblib.cpu_count(only_physical_cores=True)
    xdist_worker_count = environ.get("PYTEST_XDIST_WORKER_COUNT")
    if xdist_worker_count is not None:
        allowed_parallelism = max(allowed_parallelism // int(xdist_worker_count), 1)
    threadpool_limits(allowed_parallelism)

    if environ.get("SKLEARN_WARNINGS_AS_ERRORS", "0") != "0":
        # Programmatically extend the pytest "filterwarnings" configuration so
        # that the selected warnings are raised as errors.
        for line in get_pytest_filterwarning_lines():
            config.addinivalue_line("filterwarnings", line)


@pytest.fixture
def hide_available_pandas(monkeypatch):
    """Pretend pandas was not installed."""
    import_orig = builtins.__import__

    def mocked_import(name, *args, **kwargs):
        if name == "pandas":
            raise ImportError()
        return import_orig(name, *args, **kwargs)

    monkeypatch.setattr(builtins, "__import__", mocked_import)
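

# A minimal usage sketch for the fixture above (hypothetical test, for
# illustration only):
#
#     def test_fallback_without_pandas(hide_available_pandas):
#         with pytest.raises(ImportError):
#             import pandas  # the mocked __import__ raises ImportError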


@pytest.fixture
def print_changed_only_false():
    """Set `print_changed_only` to False for the duration of the test."""
    set_config(print_changed_only=False)
    yield
    set_config(print_changed_only=True)  # reset to default