|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
from ..base import BaseEstimator |
|
from ..exceptions import NotFittedError |
|
from ..utils import get_tags |
|
from ..utils.metaestimators import available_if |
|
from ..utils.validation import check_is_fitted |
|
|
|
|
|
def _estimator_has(attr): |
|
"""Check that final_estimator has `attr`. |
|
|
|
Used together with `available_if`. |
|
""" |
|
|
|
def check(self): |
|
|
|
getattr(self.estimator, attr) |
|
return True |
|
|
|
return check |
|
|
|
|
|
class FrozenEstimator(BaseEstimator):
    """Estimator that wraps a fitted estimator to prevent re-fitting.

    This meta-estimator takes an estimator and freezes it, in the sense that calling
    `fit` on it has no effect. `fit_predict` and `fit_transform` are also disabled.
    All other methods are delegated to the original estimator and original estimator's
    attributes are accessible as well.

    This is particularly useful when you have a fitted or a pre-trained model as a
    transformer in a pipeline, and you'd like `pipeline.fit` to have no effect on this
    step.

    Parameters
    ----------
    estimator : estimator
        The estimator which is to be kept frozen.

    See Also
    --------
    None: No similar entry in the scikit-learn documentation.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.frozen import FrozenEstimator
    >>> from sklearn.linear_model import LogisticRegression
    >>> X, y = make_classification(random_state=0)
    >>> clf = LogisticRegression(random_state=0).fit(X, y)
    >>> frozen_clf = FrozenEstimator(clf)
    >>> frozen_clf.fit(X, y)  # No-op
    FrozenEstimator(estimator=LogisticRegression(random_state=0))
    >>> frozen_clf.predict(X)  # Predictions from `clf.predict`
    array(...)
    """

    def __init__(self, estimator):
        self.estimator = estimator

    @available_if(_estimator_has("__getitem__"))
    def __getitem__(self, *args, **kwargs):
        """__getitem__ is defined in :class:`~sklearn.pipeline.Pipeline` and \
            :class:`~sklearn.compose.ColumnTransformer`.
        """
        return self.estimator.__getitem__(*args, **kwargs)

    def __getattr__(self, name):
        # Python only calls `__getattr__` after normal lookup has failed, so this
        # is the fallback that forwards unknown attributes to the wrapped
        # estimator -- except `fit_predict` and `fit_transform`, which would
        # re-fit it.
        if name in ("fit_predict", "fit_transform"):
            raise AttributeError(f"{name} is not available for frozen estimators.")
        if name == "estimator":
            # Reaching this point means `estimator` was never set on the instance
            # (e.g. it was created via `__new__` without `__init__`). Evaluating
            # `self.estimator` below would call back into `__getattr__` forever, so
            # fail with a proper AttributeError instead of a RecursionError.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'estimator'"
            )
        return getattr(self.estimator, name)

    def __sklearn_clone__(self):
        # Cloning (presumably via `sklearn.base.clone`) yields this very instance
        # rather than an unfitted copy, so the frozen estimator survives
        # meta-estimators that clone their sub-estimators.
        return self

    def __sklearn_is_fitted__(self):
        # Fitted exactly when the wrapped estimator passes `check_is_fitted`.
        try:
            check_is_fitted(self.estimator)
            return True
        except NotFittedError:
            return False

    def fit(self, X, y=None, *args, **kwargs):
        """No-op.

        As a frozen estimator, calling `fit` has no effect. The wrapped estimator
        is only checked for being fitted.

        Parameters
        ----------
        X : object
            Ignored.

        y : object, default=None
            Ignored.

        *args : tuple
            Additional positional arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        **kwargs : dict
            Additional keyword arguments. Ignored, but present for API compatibility
            with `self.estimator`.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        NotFittedError
            If the wrapped estimator has not been fitted.
        """
        check_is_fitted(self.estimator)
        return self

    def set_params(self, **kwargs):
        """Set the parameters of this estimator.

        The only valid key here is `estimator`. You cannot set the parameters of the
        inner estimator. Passing `estimator=None` leaves the current estimator
        unchanged.

        Parameters
        ----------
        **kwargs : dict
            Estimator parameters.

        Returns
        -------
        self : FrozenEstimator
            This estimator.

        Raises
        ------
        ValueError
            If any key other than `estimator` is given. In that case the
            estimator is left unmodified.
        """
        estimator = kwargs.pop("estimator", None)
        # Reject unknown keys before assigning anything so that a failed call does
        # not leave the object half-updated.
        if kwargs:
            raise ValueError(
                "You cannot set parameters of the inner estimator in a frozen "
                "estimator since calling `fit` has no effect. You can use "
                "`frozenestimator.estimator.set_params` to set parameters of the inner "
                "estimator."
            )
        if estimator is not None:
            self.estimator = estimator
        return self

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Returns a `{"estimator": estimator}` dict. The parameters of the inner
        estimator are not included.

        Parameters
        ----------
        deep : bool, default=True
            Ignored.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.
        """
        return {"estimator": self.estimator}

    def __sklearn_tags__(self):
        # Start from the wrapped estimator's tags. Deep-copy them so that setting
        # the flag below cannot leak into a tags object shared with the inner
        # estimator (in case `get_tags` does not return a fresh object).
        tags = deepcopy(get_tags(self.estimator))
        # NOTE(review): presumably tells sklearn's common estimator checks to skip
        # this wrapper -- confirm against the `Tags` definition.
        tags._skip_test = True
        return tags
|
|