File size: 13,573 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 |
"""Global configuration state and functions for management"""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import os
import threading
from contextlib import contextmanager as contextmanager
# Process-wide defaults. Each thread gets its own mutable copy of this dict the
# first time it asks for its configuration. Some defaults can be overridden
# through environment variables, which are read once at import time.
_global_config = {
    # Environment values are always strings, so ``bool(value)`` would be True
    # for any non-empty text, including "0" and "false". Parse the text
    # instead: the usual "off" spellings (and an unset/empty variable) disable
    # the flag, any other value enables it as before.
    "assume_finite": (
        os.environ.get("SKLEARN_ASSUME_FINITE", "").strip().lower()
        not in ("", "0", "false", "no", "off")
    ),
    "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
    "print_changed_only": True,
    "display": "diagram",
    "pairwise_dist_chunk_size": int(
        os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
    ),
    "enable_cython_pairwise_dist": True,
    "array_api_dispatch": False,
    "transform_output": "default",
    "enable_metadata_routing": False,
    "skip_parameter_validation": False,
}
# Per-thread storage; holds the ``global_config`` attribute lazily created by
# ``_get_threadlocal_config``.
_threadlocal = threading.local()
def _get_threadlocal_config():
    """Return the calling thread's **mutable** configuration dict.

    The dict is stored on ``_threadlocal``. On first access from a thread it
    does not exist yet, so it is created as a shallow copy of the default
    ``_global_config``.
    """
    try:
        return _threadlocal.global_config
    except AttributeError:
        # First access from this thread: seed it from the global defaults.
        _threadlocal.global_config = _global_config.copy()
        return _threadlocal.global_config
def get_config():
    """Retrieve current values for configuration set by :func:`set_config`.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.

    Examples
    --------
    >>> import sklearn
    >>> config = sklearn.get_config()
    >>> config.keys()
    dict_keys([...])
    """
    # Hand back a shallow copy rather than the live thread-local dict, so that
    # mutating the returned mapping cannot alter the active configuration.
    live_config = _get_threadlocal_config()
    return dict(live_config)
def set_config(
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Set global scikit-learn configuration.

    Every parameter left as `None` keeps its current value. The requested
    changes are applied together: if validating `array_api_dispatch` raises,
    no setting is modified.

    .. versionadded:: 0.19

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. Global default: False.

        .. versionadded:: 0.19

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. Global default: 1024.

        .. versionadded:: 0.20

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. Global default: True.

        .. versionadded:: 0.21

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. Default is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Examples
    --------
    >>> from sklearn import set_config
    >>> set_config(display='diagram')  # doctest: +SKIP
    """
    # Validate before touching any state. If the dispatch check raises, the
    # thread's configuration must not be left half-updated.
    if array_api_dispatch is not None:
        # Imported lazily, only when the option is actually supplied.
        from .utils._array_api import _check_array_api_dispatch

        _check_array_api_dispatch(array_api_dispatch)

    requested = {
        "assume_finite": assume_finite,
        "working_memory": working_memory,
        "print_changed_only": print_changed_only,
        "display": display,
        "pairwise_dist_chunk_size": pairwise_dist_chunk_size,
        "enable_cython_pairwise_dist": enable_cython_pairwise_dist,
        "array_api_dispatch": array_api_dispatch,
        "transform_output": transform_output,
        "enable_metadata_routing": enable_metadata_routing,
        "skip_parameter_validation": skip_parameter_validation,
    }

    # ``None`` means "leave unchanged", so only explicit values are written.
    local_config = _get_threadlocal_config()
    local_config.update(
        {key: value for key, value in requested.items() if value is not None}
    )
@contextmanager
def config_context(
    *,
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Context manager for global scikit-learn configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. If None, the existing value won't change.
        The default value is False.

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. If None, the existing value won't change.
        The default value is 1024.

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
        The default value is True.

        .. versionchanged:: 0.23
           Default changed from False to True.

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. If None, the existing value won't change.
        The default value is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited. The previous
    values are also restored if applying the requested settings raises, so a
    failed ``with`` statement does not leave the configuration modified.

    Examples
    --------
    >>> import sklearn
    >>> from sklearn.utils.validation import assert_all_finite
    >>> with sklearn.config_context(assume_finite=True):
    ...     assert_all_finite([float('nan')])
    >>> with sklearn.config_context(assume_finite=True):
    ...     with sklearn.config_context(assume_finite=False):
    ...         assert_all_finite([float('nan')])
    Traceback (most recent call last):
    ...
    ValueError: Input contains NaN...
    """
    # Snapshot (a copy) of the full configuration to restore on exit.
    old_config = get_config()
    try:
        # Applied inside the ``try`` so that an exception raised while
        # setting the options still triggers the restore below.
        set_config(
            assume_finite=assume_finite,
            working_memory=working_memory,
            print_changed_only=print_changed_only,
            display=display,
            pairwise_dist_chunk_size=pairwise_dist_chunk_size,
            enable_cython_pairwise_dist=enable_cython_pairwise_dist,
            array_api_dispatch=array_api_dispatch,
            transform_output=transform_output,
            enable_metadata_routing=enable_metadata_routing,
            skip_parameter_validation=skip_parameter_validation,
        )
        yield
    finally:
        # Restore every setting, including ones changed inside the block.
        set_config(**old_config)
|