File size: 4,985 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause
from copy import deepcopy
from ..base import BaseEstimator
from ..exceptions import NotFittedError
from ..utils import get_tags
from ..utils.metaestimators import available_if
from ..utils.validation import check_is_fitted
def _estimator_has(attr):
    """Build an availability check for `attr` on the wrapped `estimator`.

    The returned callable is meant to be handed to `available_if`: it takes
    the meta-estimator instance and succeeds only when `instance.estimator`
    exposes `attr`.

    Parameters
    ----------
    attr : str
        Name of the attribute looked up on `self.estimator`.

    Returns
    -------
    check : callable
        Function of one argument (`self`) that returns `True` when the
        attribute exists and otherwise lets the `AttributeError` raised by
        the lookup propagate.
    """

    def check(self):
        wrapped = self.estimator
        # Deliberately not `hasattr`: the lookup's own `AttributeError` must
        # surface unchanged when `attr` is missing.
        getattr(wrapped, attr)
        return True

    return check
class FrozenEstimator(BaseEstimator):
    """Estimator that wraps a fitted estimator to prevent re-fitting.

    This meta-estimator takes an estimator and freezes it, in the sense that calling
    `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
    All other methods are delegated to the original estimator and original estimator's
    attributes are accessible as well.

    This is particularly useful when you have a fitted or a pre-trained model as a
    transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
    step.

    Parameters
    ----------
    estimator : estimator
        The estimator which is to be kept frozen.

    See Also
    --------
    None: No similar entry in the scikit-learn documentation.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.frozen import FrozenEstimator
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(random_state=0)
    >>> clf = LogisticRegression(random_state=0).fit(X, y)
    >>> frozen_clf = FrozenEstimator(clf)
    >>> frozen_clf.fit(X, y)  # No-op
    FrozenEstimator(estimator=LogisticRegression(random_state=0))
    >>> frozen_clf.predict(X)  # Predictions from `clf.predict`
    array(...)
    """

    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_estimator_has("__getitem__"))
    def __getitem__(self, *args, **kwargs):
        """__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
        :class:`~sklearn.compose.ColumnTransformer`.
        """
        return self.estimator.__getitem__(*args, **kwargs)

    def __getattr__(self, name):
        # `__getattr__` only runs after normal attribute lookup has failed, so
        # seeing `estimator` here means the instance has no such attribute yet
        # (e.g. an object created through `__new__`). Reading `self.estimator`
        # below would then re-enter this method until `RecursionError`; raise
        # the conventional `AttributeError` instead.
        if name == "estimator":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'estimator'"
            )
        # `estimator`'s attributes are now accessible except `fit_predict` and
        # `fit_transform`
        if name in ["fit_predict", "fit_transform"]:
            raise AttributeError(f"{name} is not available for frozen estimators.")
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        # Cloning hands back this very instance, so the wrapped estimator and
        # its fitted state are shared rather than copied or reset.
        return self

    def __sklearn_is_fitted__(self):
        # Fitted-ness is entirely determined by the wrapped estimator.
        try:
            check_is_fitted(self.estimator)
            return True
        except NotFittedError:
            return False

    def fit(self, X, y, *args, **kwargs):
        """No-op.

        As a frozen estimator, calling `fit` has no effect.

        Parameters
        ----------
        X : object
            Ignored.

        y : object
            Ignored.

        *args : tuple
            Additional positional arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        **kwargs : dict
            Additional keyword arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        NotFittedError
            If the wrapped estimator is not fitted.
        """
        # Nothing is learned here; only verify that the wrapped estimator has
        # already been fitted.
        check_is_fitted(self.estimator)
        return self

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        The only valid key here is `estimator`. You cannot set the parameters of the
        inner estimator.

        Parameters
        ----------
        **kwargs : dict
            Estimator parameters.

        Returns
        -------
        self : FrozenEstimator
            This estimator.

        Raises
        ------
        ValueError
            If any key other than `estimator` is passed.
        """
        estimator = kwargs.pop("estimator", None)
        if estimator is not None:
            self.estimator = estimator
        if kwargs:
            raise ValueError(
                "You cannot set parameters of the inner estimator in a frozen "
                "estimator since calling `fit` has no effect. You can use "
                "`frozenestimator.estimator.set_params` to set parameters of the inner "
                "estimator."
            )
        # Honour the documented contract (and allow call chaining) by returning
        # the instance instead of implicitly returning `None`.
        return self

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns a `{"estimator": estimator}` dict. The parameters of the inner
        estimator are not included.

        Parameters
        ----------
        deep : bool, default=True
            Ignored.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {"estimator": self.estimator}

    def __sklearn_tags__(self):
        # Start from a deep copy of the wrapped estimator's tags so that the
        # modification below never leaks into the tags object `get_tags` returned.
        tags = deepcopy(get_tags(self.estimator))
        # NOTE(review): presumably marks this wrapper to be skipped by the
        # common estimator checks -- confirm against the tags documentation.
        tags._skip_test = True
        return tags
|