|
import builtins |
|
import time |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import pytest |
|
|
|
import sklearn |
|
from sklearn import config_context, get_config, set_config |
|
from sklearn.utils.fixes import _IS_WASM |
|
from sklearn.utils.parallel import Parallel, delayed |
|
|
|
|
|
def test_config_context():
    """Check `config_context` scoping, nesting, and argument validation."""
    # Expected global defaults, defined once so that adding a new config key
    # only requires updating this single literal.
    default_config = {
        "assume_finite": False,
        "working_memory": 1024,
        "print_changed_only": True,
        "display": "diagram",
        "array_api_dispatch": False,
        "pairwise_dist_chunk_size": 256,
        "enable_cython_pairwise_dist": True,
        "transform_output": "default",
        "enable_metadata_routing": False,
        "skip_parameter_validation": False,
    }
    assert get_config() == default_config

    # Calling config_context without entering it (no `with`) changes nothing.
    config_context(assume_finite=True)
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        assert get_config() == {**default_config, "assume_finite": True}
    assert get_config()["assume_finite"] is False

    with config_context(assume_finite=True):
        # assume_finite=None keeps the value from the enclosing context.
        with config_context(assume_finite=None):
            assert get_config()["assume_finite"] is True

        assert get_config()["assume_finite"] is True

        with config_context(assume_finite=False):
            assert get_config()["assume_finite"] is False

            with config_context(assume_finite=None):
                assert get_config()["assume_finite"] is False

                # A global setting made inside a context is not retained once
                # that context exits, even though the context itself did not
                # modify this setting.
                set_config(assume_finite=True)
                assert get_config()["assume_finite"] is True

            assert get_config()["assume_finite"] is False

        assert get_config()["assume_finite"] is True

    assert get_config() == default_config

    # No positional arguments
    with pytest.raises(TypeError):
        config_context(True)

    # No unknown arguments
    with pytest.raises(TypeError):
        config_context(do_something_else=True).__enter__()
|
|
|
|
|
def test_config_context_exception():
    """Check that `config_context` restores the previous config when an
    exception escapes the `with` block, without suppressing that exception."""
    assert get_config()["assume_finite"] is False
    # pytest.raises (instead of try/except/pass) additionally fails the test
    # if config_context were to swallow the exception.
    with pytest.raises(ValueError):
        with config_context(assume_finite=True):
            assert get_config()["assume_finite"] is True
            raise ValueError()
    assert get_config()["assume_finite"] is False
|
|
|
|
|
def test_set_config():
    """Check that `set_config` updates the global config, treats None as
    "leave unchanged", and rejects unknown keyword arguments."""
    assert get_config()["assume_finite"] is False

    # Mutate the global config inside an argument-less config_context so the
    # previous config is restored even if an assertion fails part-way through;
    # otherwise assume_finite=True could leak into subsequently run tests.
    with config_context():
        set_config(assume_finite=None)
        assert get_config()["assume_finite"] is False

        set_config(assume_finite=True)
        assert get_config()["assume_finite"] is True

        set_config(assume_finite=None)
        assert get_config()["assume_finite"] is True

        set_config(assume_finite=False)
        assert get_config()["assume_finite"] is False

    # No unknown arguments
    with pytest.raises(TypeError):
        set_config(do_something_else=True)
|
|
|
|
|
def set_assume_finite(assume_finite, sleep_duration):
    """Temporarily set ``assume_finite``, sleep, then report the value seen.

    Parameters
    ----------
    assume_finite : bool or None
        Value passed to ``config_context`` for the duration of the call.
    sleep_duration : float
        Seconds to sleep while the context is active.

    Returns
    -------
    The ``assume_finite`` entry of ``get_config()``, read after the sleep
    while the context is still active.
    """
    with config_context(assume_finite=assume_finite):
        time.sleep(sleep_duration)
        observed = get_config()["assume_finite"]
    return observed
|
|
|
|
|
@pytest.mark.parametrize("backend", ["loky", "multiprocessing", "threading"])
def test_config_threadsafe_joblib(backend):
    """Check that config changes stay local to each joblib job.

    Four jobs run on two workers and alternate ``assume_finite`` between
    False and True. They use different sleep times, intended to make the jobs
    overlap in time. Each job must report the value it set itself, unaffected
    by concurrently running jobs that set a different value.
    """
    jobs = [(False, 0.1), (True, 0.2), (False, 0.1), (True, 0.2)]

    run = Parallel(backend=backend, n_jobs=2)
    items = run(delayed(set_assume_finite)(flag, delay) for flag, delay in jobs)

    expected = [flag for flag, _ in jobs]
    assert items == expected
|
|
|
|
|
@pytest.mark.xfail(_IS_WASM, reason="cannot start threads")
def test_config_threadsafe():
    """Uses threads directly to test that the global config does not change
    between threads. Same test as `test_config_threadsafe_joblib` but with
    `ThreadPoolExecutor`."""
    assume_finites = [False, True, False, True]
    sleep_durations = [0.1, 0.2, 0.1, 0.2]

    with ThreadPoolExecutor(max_workers=2) as e:
        # Executor.map yields results in submission order, so `items` lines up
        # element-wise with `assume_finites`.
        items = list(e.map(set_assume_finite, assume_finites, sleep_durations))

    assert items == [False, True, False, True]
|
|
|
|
|
def test_config_array_api_dispatch_error(monkeypatch):
    """Check error is raised when array_api_compat is not installed.

    ``builtins.__import__`` is patched so that importing ``array_api_compat``
    raises ``ImportError``. Enabling array API dispatch through either
    ``config_context`` or ``set_config`` must then fail with a helpful message.
    """
    real_import = builtins.__import__

    def import_without_array_api_compat(name, *args, **kwargs):
        if name != "array_api_compat":
            return real_import(name, *args, **kwargs)
        raise ImportError

    monkeypatch.setattr(builtins, "__import__", import_without_array_api_compat)

    expected_msg = "array_api_compat is required"

    with pytest.raises(ImportError, match=expected_msg), config_context(
        array_api_dispatch=True
    ):
        pass

    with pytest.raises(ImportError, match=expected_msg):
        set_config(array_api_dispatch=True)
|
|
|
|
|
def test_config_array_api_dispatch_error_numpy(monkeypatch):
    """Check error when NumPy is too old.

    ``array_api_compat`` is made importable via a dummy object, and the NumPy
    version seen by ``sklearn.utils._array_api`` is patched to "1.20". Both
    ``config_context`` and ``set_config`` must then raise the NumPy version
    error when array API dispatch is requested.
    """
    real_import = builtins.__import__

    def fake_import(name, *args, **kwargs):
        # Hand back a placeholder so that importing array_api_compat succeeds.
        if name != "array_api_compat":
            return real_import(name, *args, **kwargs)
        return object()

    monkeypatch.setattr(builtins, "__import__", fake_import)
    monkeypatch.setattr(sklearn.utils._array_api.numpy, "__version__", "1.20")

    expected_msg = "NumPy must be 1.21 or newer"

    with pytest.raises(ImportError, match=expected_msg), config_context(
        array_api_dispatch=True
    ):
        pass

    with pytest.raises(ImportError, match=expected_msg):
        set_config(array_api_dispatch=True)
|
|