|
|
|
|
|
|
|
import re |
|
|
|
import numpy as np |
|
import pytest |
|
from numpy.testing import assert_array_equal |
|
|
|
from sklearn import config_context |
|
from sklearn.base import ( |
|
BaseEstimator, |
|
clone, |
|
is_classifier, |
|
is_clusterer, |
|
is_outlier_detector, |
|
is_regressor, |
|
) |
|
from sklearn.cluster import KMeans |
|
from sklearn.compose import make_column_transformer |
|
from sklearn.datasets import make_classification, make_regression |
|
from sklearn.exceptions import NotFittedError, UnsetMetadataPassedError |
|
from sklearn.frozen import FrozenEstimator |
|
from sklearn.linear_model import LinearRegression, LogisticRegression |
|
from sklearn.neighbors import LocalOutlierFactor |
|
from sklearn.pipeline import make_pipeline |
|
from sklearn.preprocessing import RobustScaler, StandardScaler |
|
from sklearn.utils._testing import set_random_state |
|
from sklearn.utils.validation import check_is_fitted |
|
|
|
|
|
@pytest.fixture
def regression_dataset():
    """Provide an ``(X, y)`` pair built by ``make_regression`` with its defaults."""
    X, y = make_regression()
    return X, y
|
|
|
|
|
@pytest.fixture
def classification_dataset():
    """Provide an ``(X, y)`` pair built by ``make_classification`` with its defaults."""
    X, y = make_classification()
    return X, y
|
|
|
|
|
@pytest.mark.parametrize(
    "estimator, dataset",
    [
        (LinearRegression(), "regression_dataset"),
        (LogisticRegression(), "classification_dataset"),
        (make_pipeline(StandardScaler(), LinearRegression()), "regression_dataset"),
        (
            make_pipeline(StandardScaler(), LogisticRegression()),
            "classification_dataset",
        ),
        (StandardScaler(), "regression_dataset"),
        (KMeans(), "regression_dataset"),
        (LocalOutlierFactor(), "regression_dataset"),
        (
            make_column_transformer(
                (StandardScaler(), [0]),
                (RobustScaler(), [1]),
            ),
            "regression_dataset",
        ),
    ],
)
@pytest.mark.parametrize(
    "method",
    ["predict", "predict_proba", "predict_log_proba", "decision_function", "transform"],
)
def test_frozen_methods(estimator, dataset, request, method):
    """Test that frozen.fit doesn't do anything, and that all other methods are
    exposed by the frozen estimator and return the same values as the estimator.
    """
    X, y = request.getfixturevalue(dataset)
    # The parametrized instances are shared by every ``method`` case; work on a
    # fresh copy so one test case never depends on state fitted by another.
    estimator = clone(estimator)
    set_random_state(estimator)
    estimator.fit(X, y)

    # Record the reference output *before* freezing and calling ``frozen.fit``.
    # ``frozen`` wraps this very object, so comparing the two afterwards could
    # not detect a ``fit`` that silently refitted the inner estimator.
    has_method = hasattr(estimator, method)
    if has_method:
        expected = getattr(estimator, method)(X)

    frozen = FrozenEstimator(estimator)
    # This should be a no-op that hands back the frozen estimator itself.
    assert frozen.fit([[1]], [1]) is frozen

    # The frozen estimator exposes exactly the methods of the inner estimator,
    # and they still produce the pre-freeze results.
    assert hasattr(frozen, method) == has_method
    if has_method:
        assert_array_equal(expected, getattr(frozen, method)(X))

    assert is_classifier(estimator) == is_classifier(frozen)
    assert is_regressor(estimator) == is_regressor(frozen)
    assert is_clusterer(estimator) == is_clusterer(frozen)
    assert is_outlier_detector(estimator) == is_outlier_detector(frozen)
|
|
|
|
|
@config_context(enable_metadata_routing=True)
def test_frozen_metadata_routing(regression_dataset):
    """Check that metadata routing still works through a FrozenEstimator.

    The wrapped pipeline forwards ``metadata`` when its step requests it, rejects
    it with a ``TypeError`` when the step opts out, and raises
    ``UnsetMetadataPassedError`` when the request is left unset.
    """

    class ConsumesMetadata(BaseEstimator):
        # Minimal consumer: when ``on_fit`` / ``on_predict`` is truthy it asserts
        # that ``metadata`` actually reached the corresponding method.
        def __init__(self, on_fit=None, on_predict=None):
            self.on_fit = on_fit
            self.on_predict = on_predict

        def fit(self, X, y, metadata=None):
            if self.on_fit:
                assert metadata is not None
            self.fitted_ = True
            return self

        def predict(self, X, metadata=None):
            if self.on_predict:
                assert metadata is not None
            return np.ones(len(X))

    X, y = regression_dataset
    step = (
        ConsumesMetadata(on_fit=True, on_predict=True)
        .set_fit_request(metadata=True)
        .set_predict_request(metadata=True)
    )
    pipeline = make_pipeline(step)
    pipeline.fit(X, y, metadata="test")
    frozen = FrozenEstimator(pipeline)

    # Requested metadata is routed by the pipeline and by its frozen wrapper.
    for predictor in (pipeline, frozen):
        predictor.predict(X, metadata="test")

    # Explicitly declining the metadata leaves nothing to consume it.
    frozen["consumesmetadata"].set_predict_request(metadata=False)
    not_routed_msg = (
        "Pipeline.predict got unexpected argument(s) {'metadata'}, which are not "
        "routed to any object."
    )
    with pytest.raises(TypeError, match=re.escape(not_routed_msg)):
        frozen.predict(X, metadata="test")

    # Leaving the request unset is an error once the metadata is passed anyway.
    frozen["consumesmetadata"].set_predict_request(metadata=None)
    with pytest.raises(UnsetMetadataPassedError):
        frozen.predict(X, metadata="test")
|
|
|
|
|
def test_composite_fit(classification_dataset):
    """Check that ``fit_transform`` and ``fit_predict`` are not reachable on a
    frozen estimator, and that using them never triggers another ``fit``.
    """

    class Estimator(BaseEstimator):
        def fit(self, X, y):
            # Count calls so that any refit after freezing is detectable.
            self._fit_counter = getattr(self, "_fit_counter", 0) + 1
            return self

        def fit_transform(self, X, y=None):
            # Only defined to check that the frozen wrapper does not expose it.
            ...

        def fit_predict(self, X, y=None):
            # Only defined to check that the frozen wrapper does not expose it.
            ...

    X, y = classification_dataset
    frozen = FrozenEstimator(Estimator().fit(X, y))

    for composite_method in ("fit_predict", "fit_transform"):
        with pytest.raises(AttributeError):
            getattr(frozen, composite_method)(X, y)

    # Only the explicit fit performed before freezing should have run.
    assert frozen._fit_counter == 1
|
|
|
|
|
def test_clone_frozen(regression_dataset):
    """Check that cloning a frozen estimator keeps the very same fitted inner
    estimator instead of a fresh copy of it."""
    X, y = regression_dataset
    fitted = LinearRegression().fit(X, y)
    frozen = FrozenEstimator(fitted)
    assert clone(frozen).estimator is fitted
|
|
|
|
|
def test_check_is_fitted(regression_dataset):
    """Check that ``check_is_fitted`` on a frozen estimator follows the fitted
    state of the wrapped estimator."""
    X, y = regression_dataset

    # Wrapping an unfitted estimator: the frozen wrapper is reported as unfitted.
    with pytest.raises(NotFittedError):
        check_is_fitted(FrozenEstimator(LinearRegression()))

    # Wrapping a fitted estimator: the check passes without raising.
    check_is_fitted(FrozenEstimator(LinearRegression().fit(X, y)))
|
|
|
|
|
def test_frozen_tags():
    """Check that a frozen estimator carries over the wrapped estimator's tags,
    with ``_skip_test`` being the exception (forced to ``True``)."""

    class Estimator(BaseEstimator):
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.input_tags.categorical = True
            return tags

    inner = Estimator()
    frozen_tags = FrozenEstimator(inner).__sklearn_tags__()
    inner_tags = inner.__sklearn_tags__()

    # The private skip marker differs between wrapper and wrapped estimator.
    assert frozen_tags._skip_test is True
    assert inner_tags._skip_test is False

    # A custom tag set by the inner estimator is visible on the wrapper too.
    assert inner_tags.input_tags.categorical is True
    assert frozen_tags.input_tags.categorical is True
|
|
|
|
|
def test_frozen_params():
    """Test that FrozenEstimator only exposes the estimator parameter."""
    est = LogisticRegression()
    frozen = FrozenEstimator(est)
    params_before = est.get_params()

    # Use a value that differs from the current one: with the default value, a
    # partially applied update of the inner estimator would go unnoticed.
    with pytest.raises(ValueError, match="You cannot set parameters of the inner"):
        frozen.set_params(estimator__C=est.C * 10)
    assert est.get_params() == params_before

    assert frozen.get_params() == {"estimator": est}

    # Replacing the wrapped estimator as a whole is still allowed.
    other_est = LocalOutlierFactor()
    frozen.set_params(estimator=other_est)
    assert frozen.get_params() == {"estimator": other_est}
|
|