File size: 3,592 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""Various utilities to help with development."""
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
import platform
import warnings
from collections.abc import Sequence
import numpy as np
from ..exceptions import DataConversionWarning
from . import _joblib, metadata_routing
from ._bunch import Bunch
from ._chunking import gen_batches, gen_even_slices
from ._estimator_html_repr import estimator_html_repr
# Make _safe_indexing importable from here for backward compat as this particular
# helper is considered semi-private and typically very useful for third-party
# libraries that want to comply with scikit-learn's estimator API. In particular,
# _safe_indexing was included in our public API documentation despite the leading
# `_` in its name.
from ._indexing import (
_safe_indexing, # noqa
resample,
shuffle,
)
from ._mask import safe_mask
from ._tags import (
ClassifierTags,
InputTags,
RegressorTags,
Tags,
TargetTags,
TransformerTags,
get_tags,
)
from .class_weight import compute_class_weight, compute_sample_weight
from .deprecation import deprecated
from .discovery import all_estimators
from .extmath import safe_sqr
from .murmurhash import murmurhash3_32
from .validation import (
as_float_array,
assert_all_finite,
check_array,
check_consistent_length,
check_random_state,
check_scalar,
check_symmetric,
check_X_y,
column_or_1d,
indexable,
)
# TODO(1.7): remove parallel_backend and register_parallel_backend
# Template for the deprecation notice. The `{}` placeholder must be filled with
# the name of the joblib replacement for each deprecated alias; passing the raw
# template to `deprecated` would show a literal "joblib.{}" to users.
msg = "deprecated in 1.5 to be removed in 1.7. Use joblib.{} instead."
register_parallel_backend = deprecated(msg.format("register_parallel_backend"))(
    _joblib.register_parallel_backend
)


# if a class, deprecated will change the object in _joblib module so we need to subclass
@deprecated(msg.format("parallel_backend"))
class parallel_backend(_joblib.parallel_backend):
    pass
# Public API of this subpackage. The order is kept stable on purpose; new names
# are appended rather than inserted. `_safe_indexing` is importable from here
# but deliberately not listed, since it is semi-private.
__all__ = [
    "murmurhash3_32",
    "as_float_array",
    "assert_all_finite",
    "check_array",
    "check_random_state",
    "compute_class_weight",
    "compute_sample_weight",
    "column_or_1d",
    "check_consistent_length",
    "check_X_y",
    "check_scalar",
    "indexable",
    "check_symmetric",
    "deprecated",
    # deprecated joblib aliases (TODO(1.7): remove)
    "parallel_backend",
    "register_parallel_backend",
    "resample",
    "shuffle",
    "all_estimators",
    "DataConversionWarning",
    "estimator_html_repr",
    "Bunch",
    "metadata_routing",
    "safe_sqr",
    "safe_mask",
    "gen_batches",
    "gen_even_slices",
    # estimator tags API (from ._tags)
    "Tags",
    "InputTags",
    "TargetTags",
    "ClassifierTags",
    "RegressorTags",
    "TransformerTags",
    "get_tags",
]
# TODO(1.7): remove
def __getattr__(name):
    """Resolve module attributes that are not found by normal lookup (PEP 562).

    Only used to serve the deprecated ``IS_PYPY`` flag.

    Parameters
    ----------
    name : str
        Name of the attribute that normal module lookup failed to find.

    Returns
    -------
    bool
        For ``name == "IS_PYPY"``: whether the running interpreter is PyPy.
        A ``FutureWarning`` is emitted because this attribute is deprecated.

    Raises
    ------
    AttributeError
        For any other ``name``.
    """
    if name == "IS_PYPY":
        warnings.warn(
            "IS_PYPY is deprecated and will be removed in 1.7.",
            FutureWarning,
            # Attribute the warning to the code that accessed IS_PYPY, not to
            # this module, so users can locate the deprecated usage.
            stacklevel=2,
        )
        return platform.python_implementation() == "PyPy"
    # Mirror CPython's own wording for missing module attributes.
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# TODO(1.7): remove tosequence
@deprecated("tosequence was deprecated in 1.5 and will be removed in 1.7")
def tosequence(x):
    """Cast iterable x to a Sequence, avoiding a copy if possible.

    Parameters
    ----------
    x : iterable
        The iterable to be converted.

    Returns
    -------
    x : Sequence
        If `x` is a NumPy array, it is returned through `np.asarray`, i.e. as
        an `ndarray` without copying the data. If `x` is already a `Sequence`,
        it is returned unchanged. Any other iterable is materialized into a
        new list.
    """
    # Arrays are checked first: an ndarray is passed through np.asarray
    # rather than being treated by the generic Sequence / list paths.
    if not isinstance(x, np.ndarray):
        return x if isinstance(x, Sequence) else list(x)
    return np.asarray(x)
|