"""Metadata routing utility tests."""

import re

import numpy as np
import pytest

from sklearn import config_context
from sklearn.base import (
    BaseEstimator,
    clone,
)
from sklearn.exceptions import UnsetMetadataPassedError
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.tests.metadata_routing_common import (
    ConsumingClassifier,
    ConsumingRegressor,
    ConsumingTransformer,
    MetaRegressor,
    MetaTransformer,
    NonConsumingClassifier,
    WeightedMetaClassifier,
    WeightedMetaRegressor,
    _Registry,
    assert_request_equal,
    assert_request_is_empty,
    check_recorded_metadata,
)
from sklearn.utils import metadata_routing
from sklearn.utils._metadata_requests import (
    COMPOSITE_METHODS,
    METHODS,
    SIMPLE_METHODS,
    MethodMetadataRequest,
    MethodPair,
    _MetadataRequester,
    request_is_alias,
    request_is_valid,
)
from sklearn.utils.metadata_routing import (
    MetadataRequest,
    MetadataRouter,
    MethodMapping,
    _RoutingNotSupportedMixin,
    get_routing_for_object,
    process_routing,
)
from sklearn.utils.validation import check_is_fitted
|
rng = np.random.RandomState(42)
N, M = 100, 4
X = rng.rand(N, M)
y = rng.randint(0, 2, size=N)
my_groups = rng.randint(0, 10, size=N)
my_weights = rng.rand(N)
my_other_weights = rng.rand(N)
|
|
|
|
|
class SimplePipeline(BaseEstimator): |
|
"""A very simple pipeline, assuming the last step is always a predictor. |
|
|
|
Parameters |
|
---------- |
|
steps : iterable of objects |
|
An iterable of transformers with the last step being a predictor. |
|
""" |
|
|
|
def __init__(self, steps): |
|
self.steps = steps |
|
|
|
def fit(self, X, y, **fit_params): |
|
self.steps_ = [] |
|
params = process_routing(self, "fit", **fit_params) |
|
X_transformed = X |
|
for i, step in enumerate(self.steps[:-1]): |
|
transformer = clone(step).fit( |
|
X_transformed, y, **params.get(f"step_{i}").fit |
|
) |
|
self.steps_.append(transformer) |
|
X_transformed = transformer.transform( |
|
X_transformed, **params.get(f"step_{i}").transform |
|
) |
|
|
|
self.steps_.append( |
|
clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit) |
|
) |
|
return self |
|
|
|
def predict(self, X, **predict_params): |
|
check_is_fitted(self) |
|
X_transformed = X |
|
params = process_routing(self, "predict", **predict_params) |
|
for i, step in enumerate(self.steps_[:-1]): |
|
            X_transformed = step.transform(
                X_transformed, **params.get(f"step_{i}").transform
            )
|
|
|
return self.steps_[-1].predict(X_transformed, **params.predictor.predict) |
|
|
|
def get_metadata_routing(self): |
|
router = MetadataRouter(owner=self.__class__.__name__) |
|
for i, step in enumerate(self.steps[:-1]): |
|
router.add( |
|
**{f"step_{i}": step}, |
|
method_mapping=MethodMapping() |
|
.add(caller="fit", callee="fit") |
|
.add(caller="fit", callee="transform") |
|
.add(caller="predict", callee="transform"), |
|
) |
|
router.add( |
|
predictor=self.steps[-1], |
|
method_mapping=MethodMapping() |
|
.add(caller="fit", callee="fit") |
|
.add(caller="predict", callee="predict"), |
|
) |
|
return router |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_assert_request_is_empty(): |
|
requests = MetadataRequest(owner="test") |
|
assert_request_is_empty(requests) |
|
|
|
requests.fit.add_request(param="foo", alias=None) |
|
|
|
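    # this should still pass, since None is the default request value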
assert_request_is_empty(requests) |
|
|
|
requests.fit.add_request(param="bar", alias="value") |
|
with pytest.raises(AssertionError): |
|
|
|
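        # now the request is no longer empty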
assert_request_is_empty(requests) |
|
|
|
|
|
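    # ...but one can exclude a method from the check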
assert_request_is_empty(requests, exclude="fit") |
|
|
|
requests.score.add_request(param="carrot", alias=True) |
|
with pytest.raises(AssertionError): |
|
|
|
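        # excluding "fit" alone is no longer enough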
assert_request_is_empty(requests, exclude="fit") |
|
|
|
|
|
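    # excluding both "fit" and "score" makes the check pass again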
assert_request_is_empty(requests, exclude=["fit", "score"]) |
|
|
|
|
|
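    # a MetadataRouter can be checked for emptiness as well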
assert_request_is_empty( |
|
MetadataRouter(owner="test") |
|
.add_self_request(WeightedMetaRegressor(estimator=None)) |
|
.add( |
|
estimator=ConsumingRegressor(), |
|
method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
|
) |
|
) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"estimator", |
|
[ |
|
ConsumingClassifier(registry=_Registry()), |
|
ConsumingRegressor(registry=_Registry()), |
|
ConsumingTransformer(registry=_Registry()), |
|
WeightedMetaClassifier(estimator=ConsumingClassifier(), registry=_Registry()), |
|
WeightedMetaRegressor(estimator=ConsumingRegressor(), registry=_Registry()), |
|
], |
|
) |
|
@config_context(enable_metadata_routing=True) |
|
def test_estimator_puts_self_in_registry(estimator): |
|
"""Check that an estimator puts itself in the registry upon fit.""" |
|
estimator.fit(X, y) |
|
assert estimator in estimator.registry |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"val, res", |
|
[ |
|
(False, False), |
|
(True, False), |
|
(None, False), |
|
("$UNUSED$", False), |
|
("$WARN$", False), |
|
("invalid-input", False), |
|
("valid_arg", True), |
|
], |
|
) |
|
@config_context(enable_metadata_routing=True) |
|
def test_request_type_is_alias(val, res): |
|
|
|
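    # only a plain string which is not one of the reserved values is an alias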
assert request_is_alias(val) == res |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"val, res", |
|
[ |
|
(False, True), |
|
(True, True), |
|
(None, True), |
|
("$UNUSED$", True), |
|
("$WARN$", True), |
|
("invalid-input", False), |
|
("alias_arg", False), |
|
], |
|
) |
|
@config_context(enable_metadata_routing=True) |
|
def test_request_type_is_valid(val, res): |
|
|
|
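    # True, False, None, and the reserved sentinels are valid request values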
assert request_is_valid(val) == res |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_default_requests(): |
|
class OddEstimator(BaseEstimator): |
|
__metadata_request__fit = { |
|
|
|
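            # set a different default request value than the usual None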
"sample_weight": True |
|
} |
|
|
|
odd_request = get_routing_for_object(OddEstimator()) |
|
assert odd_request.fit.requests == {"sample_weight": True} |
|
|
|
|
|
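    # a non-consuming estimator has an empty request by default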
assert not len(get_routing_for_object(NonConsumingClassifier()).fit.requests) |
|
assert_request_is_empty(NonConsumingClassifier().get_metadata_routing()) |
|
|
|
trs_request = get_routing_for_object(ConsumingTransformer()) |
|
assert trs_request.fit.requests == { |
|
"sample_weight": None, |
|
"metadata": None, |
|
} |
|
assert trs_request.transform.requests == {"metadata": None, "sample_weight": None} |
|
assert_request_is_empty(trs_request) |
|
|
|
est_request = get_routing_for_object(ConsumingClassifier()) |
|
assert est_request.fit.requests == { |
|
"sample_weight": None, |
|
"metadata": None, |
|
} |
|
assert_request_is_empty(est_request) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_default_request_override(): |
|
"""Test that default requests are correctly overridden regardless of the ASCII order |
|
of the class names, hence testing small and capital letter class name starts. |
|
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28430 |
|
""" |
|
|
|
class Base(BaseEstimator): |
|
__metadata_request__split = {"groups": True} |
|
|
|
class class_1(Base): |
|
__metadata_request__split = {"groups": "sample_domain"} |
|
|
|
class Class_1(Base): |
|
__metadata_request__split = {"groups": "sample_domain"} |
|
|
|
assert_request_equal( |
|
class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} |
|
) |
|
assert_request_equal( |
|
Class_1()._get_metadata_request(), {"split": {"groups": "sample_domain"}} |
|
) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_process_routing_invalid_method(): |
|
with pytest.raises(TypeError, match="Can only route and process input"): |
|
process_routing(ConsumingClassifier(), "invalid_method", groups=my_groups) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_process_routing_invalid_object(): |
|
class InvalidObject: |
|
pass |
|
|
|
with pytest.raises(AttributeError, match="either implement the routing method"): |
|
process_routing(InvalidObject(), "fit", groups=my_groups) |
|
|
|
|
|
@pytest.mark.parametrize("method", METHODS) |
|
@pytest.mark.parametrize("default", [None, "default", []]) |
|
@config_context(enable_metadata_routing=True) |
|
def test_process_routing_empty_params_get_with_default(method, default): |
|
empty_params = {} |
|
routed_params = process_routing(ConsumingClassifier(), "fit", **empty_params) |
|
|
|
|
|
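    # with no metadata passed, any key yields one empty dict per method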
params_for_method = routed_params[method] |
|
assert isinstance(params_for_method, dict) |
|
assert set(params_for_method.keys()) == set(METHODS) |
|
|
|
|
|
default_params_for_method = routed_params.get(method, default=default) |
|
assert default_params_for_method == params_for_method |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_simple_metadata_routing(): |
|
|
|
|
|
|
|
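    # the sub-estimator neither consumes nor requests any metadata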
clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) |
|
clf.fit(X, y) |
|
|
|
|
|
|
|
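    # the meta-estimator consumes sample_weight itself; nothing needs to be
    # routed to the non-consuming sub-estimator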
clf = WeightedMetaClassifier(estimator=NonConsumingClassifier()) |
|
clf.fit(X, y, sample_weight=my_weights) |
|
|
|
|
|
|
|
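    # a consuming sub-estimator without an explicit request raises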
clf = WeightedMetaClassifier(estimator=ConsumingClassifier()) |
|
err_message = ( |
|
"[sample_weight] are passed but are not explicitly set as requested or" |
|
" not requested for ConsumingClassifier.fit" |
|
) |
|
with pytest.raises(ValueError, match=re.escape(err_message)): |
|
clf.fit(X, y, sample_weight=my_weights) |
|
|
|
|
|
|
|
|
|
|
|
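    # explicitly un-requesting the metadata means it is not routed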
clf = WeightedMetaClassifier( |
|
estimator=ConsumingClassifier().set_fit_request(sample_weight=False) |
|
) |
|
|
|
|
|
|
|
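    # no error is raised, and the sub-estimator records no metadata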
clf.fit(X, y, sample_weight=my_weights) |
|
check_recorded_metadata(clf.estimator_, method="fit", parent="fit") |
|
|
|
|
|
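    # requesting the metadata routes it to the sub-estimator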
clf = WeightedMetaClassifier( |
|
estimator=ConsumingClassifier().set_fit_request(sample_weight=True) |
|
) |
|
clf.fit(X, y, sample_weight=my_weights) |
|
check_recorded_metadata( |
|
clf.estimator_, method="fit", parent="fit", sample_weight=my_weights |
|
) |
|
|
|
|
|
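    # with an alias, the sub-estimator receives "alternative_weight" as its
    # sample_weight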
clf = WeightedMetaClassifier( |
|
estimator=ConsumingClassifier().set_fit_request( |
|
sample_weight="alternative_weight" |
|
) |
|
) |
|
clf.fit(X, y, alternative_weight=my_weights) |
|
check_recorded_metadata( |
|
clf.estimator_, method="fit", parent="fit", sample_weight=my_weights |
|
) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_nested_routing(): |
|
|
|
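    # check that metadata is routed correctly through nested meta-estimators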
pipeline = SimplePipeline( |
|
[ |
|
MetaTransformer( |
|
transformer=ConsumingTransformer() |
|
.set_fit_request(metadata=True, sample_weight=False) |
|
.set_transform_request(sample_weight=True, metadata=False) |
|
), |
|
WeightedMetaRegressor( |
|
estimator=ConsumingRegressor() |
|
.set_fit_request(sample_weight="inner_weights", metadata=False) |
|
.set_predict_request(sample_weight=False) |
|
).set_fit_request(sample_weight="outer_weights"), |
|
] |
|
) |
|
w1, w2, w3 = [1], [2], [3] |
|
pipeline.fit( |
|
X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 |
|
) |
|
check_recorded_metadata( |
|
pipeline.steps_[0].transformer_, |
|
method="fit", |
|
parent="fit", |
|
metadata=my_groups, |
|
) |
|
check_recorded_metadata( |
|
pipeline.steps_[0].transformer_, |
|
method="transform", |
|
parent="fit", |
|
sample_weight=w1, |
|
) |
|
check_recorded_metadata( |
|
pipeline.steps_[1], method="fit", parent="fit", sample_weight=w2 |
|
) |
|
check_recorded_metadata( |
|
pipeline.steps_[1].estimator_, method="fit", parent="fit", sample_weight=w3 |
|
) |
|
|
|
pipeline.predict(X, sample_weight=w3) |
|
check_recorded_metadata( |
|
pipeline.steps_[0].transformer_, |
|
method="transform", |
|
parent="fit", |
|
sample_weight=w3, |
|
) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_nested_routing_conflict(): |
|
|
|
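    # requesting the same metadata for the router itself and, without an alias,
    # for its children is a conflict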
pipeline = SimplePipeline( |
|
[ |
|
MetaTransformer( |
|
transformer=ConsumingTransformer() |
|
.set_fit_request(metadata=True, sample_weight=False) |
|
.set_transform_request(sample_weight=True) |
|
), |
|
WeightedMetaRegressor( |
|
estimator=ConsumingRegressor().set_fit_request(sample_weight=True) |
|
).set_fit_request(sample_weight="outer_weights"), |
|
] |
|
) |
|
w1, w2 = [1], [2] |
|
with pytest.raises( |
|
ValueError, |
|
match=( |
|
re.escape( |
|
"In WeightedMetaRegressor, there is a conflict on sample_weight between" |
|
" what is requested for this estimator and what is requested by its" |
|
" children. You can resolve this conflict by using an alias for the" |
|
" child estimator(s) requested metadata." |
|
) |
|
), |
|
): |
|
pipeline.fit(X, y, metadata=my_groups, sample_weight=w1, outer_weights=w2) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_invalid_metadata(): |
|
|
|
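    # passing a metadata which is not requested anywhere raises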
trs = MetaTransformer( |
|
transformer=ConsumingTransformer().set_transform_request(sample_weight=True) |
|
) |
|
with pytest.raises( |
|
TypeError, |
|
match=(re.escape("transform got unexpected argument(s) {'other_param'}")), |
|
): |
|
trs.fit(X, y).transform(X, other_param=my_weights) |
|
|
|
|
|
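    # passing a metadata which is explicitly un-requested also raises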
trs = MetaTransformer( |
|
transformer=ConsumingTransformer().set_transform_request(sample_weight=False) |
|
) |
|
with pytest.raises( |
|
TypeError, |
|
match=(re.escape("transform got unexpected argument(s) {'sample_weight'}")), |
|
): |
|
trs.fit(X, y).transform(X, sample_weight=my_weights) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_get_metadata_routing(): |
|
class TestDefaultsBadMethodName(_MetadataRequester): |
|
__metadata_request__fit = { |
|
"sample_weight": None, |
|
"my_param": None, |
|
} |
|
__metadata_request__score = { |
|
"sample_weight": None, |
|
"my_param": True, |
|
"my_other_param": None, |
|
} |
|
|
|
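        # "other_method" is not a supported method, so this should raise below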
__metadata_request__other_method = {"my_param": True} |
|
|
|
class TestDefaults(_MetadataRequester): |
|
__metadata_request__fit = { |
|
"sample_weight": None, |
|
"my_other_param": None, |
|
} |
|
__metadata_request__score = { |
|
"sample_weight": None, |
|
"my_param": True, |
|
"my_other_param": None, |
|
} |
|
__metadata_request__predict = {"my_param": True} |
|
|
|
with pytest.raises( |
|
AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" |
|
): |
|
TestDefaultsBadMethodName().get_metadata_routing() |
|
|
|
expected = { |
|
"score": { |
|
"my_param": True, |
|
"my_other_param": None, |
|
"sample_weight": None, |
|
}, |
|
"fit": { |
|
"my_other_param": None, |
|
"sample_weight": None, |
|
}, |
|
"predict": {"my_param": True}, |
|
} |
|
assert_request_equal(TestDefaults().get_metadata_routing(), expected) |
|
|
|
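    # aliasing a parameter via set_score_request is reflected in the routing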
est = TestDefaults().set_score_request(my_param="other_param") |
|
expected = { |
|
"score": { |
|
"my_param": "other_param", |
|
"my_other_param": None, |
|
"sample_weight": None, |
|
}, |
|
"fit": { |
|
"my_other_param": None, |
|
"sample_weight": None, |
|
}, |
|
"predict": {"my_param": True}, |
|
} |
|
assert_request_equal(est.get_metadata_routing(), expected) |
|
|
|
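    # and so is setting a request value via set_fit_request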
est = TestDefaults().set_fit_request(sample_weight=True) |
|
expected = { |
|
"score": { |
|
"my_param": True, |
|
"my_other_param": None, |
|
"sample_weight": None, |
|
}, |
|
"fit": { |
|
"my_other_param": None, |
|
"sample_weight": True, |
|
}, |
|
"predict": {"my_param": True}, |
|
} |
|
assert_request_equal(est.get_metadata_routing(), expected) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_setting_default_requests(): |
|
|
|
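    # mapping from a class to its expected default fit request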
test_cases = dict() |
|
|
|
class ExplicitRequest(BaseEstimator): |
|
|
|
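        # `prop` is not in fit's signature; the request is declared explicitly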
__metadata_request__fit = {"prop": None} |
|
|
|
def fit(self, X, y, **kwargs): |
|
return self |
|
|
|
test_cases[ExplicitRequest] = {"prop": None} |
|
|
|
class ExplicitRequestOverwrite(BaseEstimator): |
|
|
|
|
|
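        # `prop` is in fit's signature, and the default request is overridden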
__metadata_request__fit = {"prop": True} |
|
|
|
def fit(self, X, y, prop=None, **kwargs): |
|
return self |
|
|
|
test_cases[ExplicitRequestOverwrite] = {"prop": True} |
|
|
|
class ImplicitRequest(BaseEstimator): |
|
|
|
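        # `prop` is inferred from fit's signature with the default request value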
def fit(self, X, y, prop=None, **kwargs): |
|
return self |
|
|
|
test_cases[ImplicitRequest] = {"prop": None} |
|
|
|
class ImplicitRequestRemoval(BaseEstimator): |
|
|
|
|
|
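        # `prop` is in fit's signature but is removed from the request via UNUSED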
__metadata_request__fit = {"prop": metadata_routing.UNUSED} |
|
|
|
def fit(self, X, y, prop=None, **kwargs): |
|
return self |
|
|
|
test_cases[ImplicitRequestRemoval] = {} |
|
|
|
for Klass, requests in test_cases.items(): |
|
assert get_routing_for_object(Klass()).fit.requests == requests |
|
assert_request_is_empty(Klass().get_metadata_routing(), exclude="fit") |
|
Klass().fit(None, None) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_removing_non_existing_param_raises(): |
|
"""Test that removing a metadata using UNUSED which doesn't exist raises.""" |
|
|
|
class InvalidRequestRemoval(BaseEstimator): |
|
|
|
|
|
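        # `prop` is not a parameter of fit, so it cannot be marked as UNUSED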
__metadata_request__fit = {"prop": metadata_routing.UNUSED} |
|
|
|
def fit(self, X, y, **kwargs): |
|
return self |
|
|
|
with pytest.raises(ValueError, match="Trying to remove parameter"): |
|
InvalidRequestRemoval().get_metadata_routing() |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_method_metadata_request(): |
|
mmr = MethodMetadataRequest(owner="test", method="fit") |
|
|
|
with pytest.raises(ValueError, match="The alias you're setting for"): |
|
mmr.add_request(param="foo", alias=1.4) |
|
|
|
mmr.add_request(param="foo", alias=None) |
|
assert mmr.requests == {"foo": None} |
|
mmr.add_request(param="foo", alias=False) |
|
assert mmr.requests == {"foo": False} |
|
mmr.add_request(param="foo", alias=True) |
|
assert mmr.requests == {"foo": True} |
|
mmr.add_request(param="foo", alias="foo") |
|
assert mmr.requests == {"foo": True} |
|
mmr.add_request(param="foo", alias="bar") |
|
assert mmr.requests == {"foo": "bar"} |
|
assert mmr._get_param_names(return_alias=False) == {"foo"} |
|
assert mmr._get_param_names(return_alias=True) == {"bar"} |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_get_routing_for_object(): |
|
class Consumer(BaseEstimator): |
|
__metadata_request__fit = {"prop": None} |
|
|
|
assert_request_is_empty(get_routing_for_object(None)) |
|
assert_request_is_empty(get_routing_for_object(object())) |
|
|
|
mr = MetadataRequest(owner="test") |
|
mr.fit.add_request(param="foo", alias="bar") |
|
mr_factory = get_routing_for_object(mr) |
|
assert_request_is_empty(mr_factory, exclude="fit") |
|
assert mr_factory.fit.requests == {"foo": "bar"} |
|
|
|
mr = get_routing_for_object(Consumer()) |
|
assert_request_is_empty(mr, exclude="fit") |
|
assert mr.fit.requests == {"prop": None} |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metadata_request_consumes_method(): |
|
"""Test that MetadataRequest().consumes() method works as expected.""" |
|
request = MetadataRouter(owner="test") |
|
assert request.consumes(method="fit", params={"foo"}) == set() |
|
|
|
request = MetadataRequest(owner="test") |
|
request.fit.add_request(param="foo", alias=True) |
|
assert request.consumes(method="fit", params={"foo"}) == {"foo"} |
|
|
|
request = MetadataRequest(owner="test") |
|
request.fit.add_request(param="foo", alias="bar") |
|
assert request.consumes(method="fit", params={"bar", "foo"}) == {"bar"} |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metadata_router_consumes_method(): |
|
"""Test that MetadataRouter().consumes method works as expected.""" |
|
|
|
|
|
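    # each case: (router object, params passed, params reported as consumed)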
cases = [ |
|
( |
|
WeightedMetaRegressor( |
|
estimator=ConsumingRegressor().set_fit_request(sample_weight=True) |
|
), |
|
{"sample_weight"}, |
|
{"sample_weight"}, |
|
), |
|
( |
|
WeightedMetaRegressor( |
|
estimator=ConsumingRegressor().set_fit_request( |
|
sample_weight="my_weights" |
|
) |
|
), |
|
{"my_weights", "sample_weight"}, |
|
{"my_weights"}, |
|
), |
|
] |
|
|
|
    for obj, input_params, consumed in cases:
        assert (
            obj.get_metadata_routing().consumes(method="fit", params=input_params)
            == consumed
        )
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metaestimator_warnings(): |
|
class WeightedMetaRegressorWarn(WeightedMetaRegressor): |
|
__metadata_request__fit = {"sample_weight": metadata_routing.WARN} |
|
|
|
with pytest.warns( |
|
UserWarning, match="Support for .* has recently been added to this class" |
|
): |
|
WeightedMetaRegressorWarn( |
|
estimator=LinearRegression().set_fit_request(sample_weight=False) |
|
).fit(X, y, sample_weight=my_weights) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_estimator_warnings(): |
|
class ConsumingRegressorWarn(ConsumingRegressor): |
|
__metadata_request__fit = {"sample_weight": metadata_routing.WARN} |
|
|
|
with pytest.warns( |
|
UserWarning, match="Support for .* has recently been added to this class" |
|
): |
|
MetaRegressor(estimator=ConsumingRegressorWarn()).fit( |
|
X, y, sample_weight=my_weights |
|
) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"obj, string", |
|
[ |
|
( |
|
MethodMetadataRequest(owner="test", method="fit").add_request( |
|
param="foo", alias="bar" |
|
), |
|
"{'foo': 'bar'}", |
|
), |
|
( |
|
MetadataRequest(owner="test"), |
|
"{}", |
|
), |
|
( |
|
MetadataRouter(owner="test").add( |
|
estimator=ConsumingRegressor(), |
|
method_mapping=MethodMapping().add(caller="predict", callee="predict"), |
|
), |
|
( |
|
"{'estimator': {'mapping': [{'caller': 'predict', 'callee':" |
|
" 'predict'}], 'router': {'fit': {'sample_weight': None, 'metadata':" |
|
" None}, 'partial_fit': {'sample_weight': None, 'metadata': None}," |
|
" 'predict': {'sample_weight': None, 'metadata': None}, 'score':" |
|
" {'sample_weight': None, 'metadata': None}}}}" |
|
), |
|
), |
|
], |
|
) |
|
@config_context(enable_metadata_routing=True) |
|
def test_string_representations(obj, string): |
|
assert str(obj) == string |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"obj, method, inputs, err_cls, err_msg", |
|
[ |
|
( |
|
MethodMapping(), |
|
"add", |
|
{"caller": "fit", "callee": "invalid"}, |
|
ValueError, |
|
"Given callee", |
|
), |
|
( |
|
MethodMapping(), |
|
"add", |
|
{"caller": "invalid", "callee": "fit"}, |
|
ValueError, |
|
"Given caller", |
|
), |
|
( |
|
MetadataRouter(owner="test"), |
|
"add_self_request", |
|
{"obj": MetadataRouter(owner="test")}, |
|
ValueError, |
|
"Given `obj` is neither a `MetadataRequest` nor does it implement", |
|
), |
|
( |
|
ConsumingClassifier(), |
|
"set_fit_request", |
|
{"invalid": True}, |
|
TypeError, |
|
"Unexpected args", |
|
), |
|
], |
|
) |
|
@config_context(enable_metadata_routing=True) |
|
def test_validations(obj, method, inputs, err_cls, err_msg): |
|
with pytest.raises(err_cls, match=err_msg): |
|
getattr(obj, method)(**inputs) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_methodmapping(): |
|
mm = ( |
|
MethodMapping() |
|
.add(caller="fit", callee="transform") |
|
.add(caller="fit", callee="fit") |
|
) |
|
|
|
mm_list = list(mm) |
|
assert mm_list[0] == ("fit", "transform") |
|
assert mm_list[1] == ("fit", "fit") |
|
|
|
mm = MethodMapping() |
|
for method in METHODS: |
|
mm.add(caller=method, callee=method) |
|
assert MethodPair(method, method) in mm._routes |
|
assert len(mm._routes) == len(METHODS) |
|
|
|
mm = MethodMapping().add(caller="score", callee="score") |
|
assert repr(mm) == "[{'caller': 'score', 'callee': 'score'}]" |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metadatarouter_add_self_request(): |
|
|
|
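    # adding a MetadataRequest as a self request stores a copy of it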
request = MetadataRequest(owner="nested") |
|
request.fit.add_request(param="param", alias=True) |
|
router = MetadataRouter(owner="test").add_self_request(request) |
|
assert str(router._self_request) == str(request) |
|
|
|
assert router._self_request is not request |
|
|
|
|
|
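    # a consumer can also be added as a self request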
est = ConsumingRegressor().set_fit_request(sample_weight="my_weights") |
|
router = MetadataRouter(owner="test").add_self_request(obj=est) |
|
assert str(router._self_request) == str(est.get_metadata_routing()) |
|
assert router._self_request is not est.get_metadata_routing() |
|
|
|
|
|
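    # for a router, only its own (self) request part is kept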
est = WeightedMetaRegressor( |
|
estimator=ConsumingRegressor().set_fit_request(sample_weight="nested_weights") |
|
) |
|
router = MetadataRouter(owner="test").add_self_request(obj=est) |
|
|
|
assert str(router._self_request) == str(est._get_metadata_request()) |
|
|
|
|
|
assert str(router._self_request) != str(est.get_metadata_routing()) |
|
|
|
assert router._self_request is not est._get_metadata_request() |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metadata_routing_add(): |
|
|
|
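    # add an estimator with a fit -> fit method mapping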
router = MetadataRouter(owner="test").add( |
|
est=ConsumingRegressor().set_fit_request(sample_weight="weights"), |
|
method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
|
) |
|
assert ( |
|
str(router) |
|
== "{'est': {'mapping': [{'caller': 'fit', 'callee': 'fit'}], 'router': {'fit':" |
|
" {'sample_weight': 'weights', 'metadata': None}, 'partial_fit':" |
|
" {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" |
|
" None, 'metadata': None}, 'score': {'sample_weight': None, 'metadata':" |
|
" None}}}}" |
|
) |
|
|
|
|
|
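    # add an estimator with a fit -> score method mapping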
router = MetadataRouter(owner="test").add( |
|
method_mapping=MethodMapping().add(caller="fit", callee="score"), |
|
est=ConsumingRegressor().set_score_request(sample_weight=True), |
|
) |
|
assert ( |
|
str(router) |
|
== "{'est': {'mapping': [{'caller': 'fit', 'callee': 'score'}], 'router':" |
|
" {'fit': {'sample_weight': None, 'metadata': None}, 'partial_fit':" |
|
" {'sample_weight': None, 'metadata': None}, 'predict': {'sample_weight':" |
|
" None, 'metadata': None}, 'score': {'sample_weight': True, 'metadata':" |
|
" None}}}}" |
|
) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_metadata_routing_get_param_names(): |
|
router = ( |
|
MetadataRouter(owner="test") |
|
.add_self_request( |
|
WeightedMetaRegressor(estimator=ConsumingRegressor()).set_fit_request( |
|
sample_weight="self_weights" |
|
) |
|
) |
|
.add( |
|
trs=ConsumingTransformer().set_fit_request( |
|
sample_weight="transform_weights" |
|
), |
|
method_mapping=MethodMapping().add(caller="fit", callee="fit"), |
|
) |
|
) |
|
|
|
assert ( |
|
str(router) |
|
== "{'$self_request': {'fit': {'sample_weight': 'self_weights'}, 'score':" |
|
" {'sample_weight': None}}, 'trs': {'mapping': [{'caller': 'fit', 'callee':" |
|
" 'fit'}], 'router': {'fit': {'sample_weight': 'transform_weights'," |
|
" 'metadata': None}, 'transform': {'sample_weight': None, 'metadata': None}," |
|
" 'inverse_transform': {'sample_weight': None, 'metadata': None}}}}" |
|
) |
|
|
|
assert router._get_param_names( |
|
method="fit", return_alias=True, ignore_self_request=False |
|
) == {"transform_weights", "metadata", "self_weights"} |
|
|
|
assert router._get_param_names( |
|
method="fit", return_alias=False, ignore_self_request=False |
|
) == {"sample_weight", "metadata", "transform_weights"} |
|
|
|
assert router._get_param_names( |
|
method="fit", return_alias=False, ignore_self_request=True |
|
) == {"metadata", "transform_weights"} |
|
|
|
assert router._get_param_names( |
|
method="fit", return_alias=True, ignore_self_request=True |
|
) == router._get_param_names( |
|
method="fit", return_alias=False, ignore_self_request=True |
|
) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_method_generation(): |
|
|
|
|
|
|
|
|
|
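    # set_{method}_request methods should only be generated for methods which
    # accept metadata; here no method does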
class SimpleEstimator(BaseEstimator): |
|
|
|
def fit(self, X, y): |
|
pass |
|
|
|
def fit_transform(self, X, y): |
|
pass |
|
|
|
def fit_predict(self, X, y): |
|
pass |
|
|
|
def partial_fit(self, X, y): |
|
pass |
|
|
|
def predict(self, X): |
|
pass |
|
|
|
def predict_proba(self, X): |
|
pass |
|
|
|
def predict_log_proba(self, X): |
|
pass |
|
|
|
def decision_function(self, X): |
|
pass |
|
|
|
def score(self, X, y): |
|
pass |
|
|
|
def split(self, X, y=None): |
|
pass |
|
|
|
def transform(self, X): |
|
pass |
|
|
|
def inverse_transform(self, X): |
|
pass |
|
|
|
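    # no method accepts metadata, so no set_{method}_request is generated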
for method in METHODS: |
|
assert not hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
|
class SimpleEstimator(BaseEstimator): |
|
|
|
def fit(self, X, y, sample_weight=None): |
|
pass |
|
|
|
def fit_transform(self, X, y, sample_weight=None): |
|
pass |
|
|
|
def fit_predict(self, X, y, sample_weight=None): |
|
pass |
|
|
|
def partial_fit(self, X, y, sample_weight=None): |
|
pass |
|
|
|
def predict(self, X, sample_weight=None): |
|
pass |
|
|
|
def predict_proba(self, X, sample_weight=None): |
|
pass |
|
|
|
def predict_log_proba(self, X, sample_weight=None): |
|
pass |
|
|
|
def decision_function(self, X, sample_weight=None): |
|
pass |
|
|
|
def score(self, X, y, sample_weight=None): |
|
pass |
|
|
|
def split(self, X, y=None, sample_weight=None): |
|
pass |
|
|
|
def transform(self, X, sample_weight=None): |
|
pass |
|
|
|
def inverse_transform(self, X, sample_weight=None): |
|
pass |
|
|
|
|
|
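    # composite methods (fit_transform, fit_predict) get no set_{method}_request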
for method in COMPOSITE_METHODS: |
|
assert not hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
|
|
|
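    # every simple method accepts sample_weight, so its set method exists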
for method in SIMPLE_METHODS: |
|
assert hasattr(SimpleEstimator(), f"set_{method}_request") |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_composite_methods(): |
|
|
|
|
|
|
|
|
|
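    # the request of a composite method is the union of the requests of its
    # simple methods; conflicting values raise an error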
class SimpleEstimator(BaseEstimator): |
|
|
|
def fit(self, X, y, foo=None, bar=None): |
|
pass |
|
|
|
def predict(self, X, foo=None, bar=None): |
|
pass |
|
|
|
def transform(self, X, other_param=None): |
|
pass |
|
|
|
est = SimpleEstimator() |
|
|
|
|
|
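    # default request values are None for all parameters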
assert est.get_metadata_routing().fit_transform.requests == { |
|
"bar": None, |
|
"foo": None, |
|
"other_param": None, |
|
} |
|
assert est.get_metadata_routing().fit_predict.requests == {"bar": None, "foo": None} |
|
|
|
|
|
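    # setting a request on fit only makes fit_predict inconsistent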
est.set_fit_request(foo=True, bar="test") |
|
with pytest.raises(ValueError, match="Conflicting metadata requests for"): |
|
est.get_metadata_routing().fit_predict |
|
|
|
|
|
|
|
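    # setting a different value on predict still conflicts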
est.set_predict_request(bar=True) |
|
with pytest.raises(ValueError, match="Conflicting metadata requests for"): |
|
est.get_metadata_routing().fit_predict |
|
|
|
|
|
|
|
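    # matching values on fit and predict resolve the conflict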
est.set_predict_request(foo=True, bar="test") |
|
est.get_metadata_routing().fit_predict |
|
|
|
|
|
|
|
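    # fit_transform combines fit's and transform's requests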
est.set_transform_request(other_param=True) |
|
assert est.get_metadata_routing().fit_transform.requests == { |
|
"bar": "test", |
|
"foo": True, |
|
"other_param": True, |
|
} |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_no_feature_flag_raises_error(): |
|
"""Test that when feature flag disabled, set_{method}_requests raises.""" |
|
with config_context(enable_metadata_routing=False): |
|
with pytest.raises(RuntimeError, match="This method is only available"): |
|
ConsumingClassifier().set_fit_request(sample_weight=True) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_none_metadata_passed(): |
|
"""Test that passing None as metadata when not requested doesn't raise""" |
|
MetaRegressor(estimator=ConsumingRegressor()).fit(X, y, sample_weight=None) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_no_metadata_always_works(): |
|
"""Test that when no metadata is passed, having a meta-estimator which does |
|
not yet support metadata routing works. |
|
|
|
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28246 |
|
""" |
|
|
|
class Estimator(_RoutingNotSupportedMixin, BaseEstimator): |
|
def fit(self, X, y, metadata=None): |
|
return self |
|
|
|
|
|
MetaRegressor(estimator=Estimator()).fit(X, y) |
|
|
|
with pytest.raises( |
|
NotImplementedError, match="Estimator has not implemented metadata routing yet." |
|
): |
|
MetaRegressor(estimator=Estimator()).fit(X, y, metadata=my_groups) |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_unsetmetadatapassederror_correct(): |
|
"""Test that UnsetMetadataPassedError raises the correct error message when |
|
set_{method}_request is not set in nested cases.""" |
|
weighted_meta = WeightedMetaClassifier(estimator=ConsumingClassifier()) |
|
pipe = SimplePipeline([weighted_meta]) |
|
msg = re.escape( |
|
"[metadata] are passed but are not explicitly set as requested or not requested" |
|
" for ConsumingClassifier.fit, which is used within WeightedMetaClassifier.fit." |
|
" Call `ConsumingClassifier.set_fit_request({metadata}=True/False)` for each" |
|
" metadata you want to request/ignore." |
|
) |
|
|
|
with pytest.raises(UnsetMetadataPassedError, match=msg): |
|
pipe.fit(X, y, metadata="blah") |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_unsetmetadatapassederror_correct_for_composite_methods(): |
|
"""Test that UnsetMetadataPassedError raises the correct error message when |
|
composite metadata request methods are not set in nested cases.""" |
|
consuming_transformer = ConsumingTransformer() |
|
pipe = Pipeline([("consuming_transformer", consuming_transformer)]) |
|
|
|
msg = re.escape( |
|
"[metadata] are passed but are not explicitly set as requested or not requested" |
|
" for ConsumingTransformer.fit_transform, which is used within" |
|
" Pipeline.fit_transform. Call" |
|
" `ConsumingTransformer.set_fit_request({metadata}=True/False)" |
|
".set_transform_request({metadata}=True/False)`" |
|
" for each metadata you want to request/ignore." |
|
) |
|
with pytest.raises(UnsetMetadataPassedError, match=msg): |
|
pipe.fit_transform(X, y, metadata="blah") |
|
|
|
|
|
@config_context(enable_metadata_routing=True) |
|
def test_unbound_set_methods_work(): |
|
"""Tests that if the set_{method}_request is unbound, it still works. |
|
|
|
Also test that passing positional arguments to the set_{method}_request fails |
|
with the right TypeError message. |
|
|
|
Non-regression test for https://github.com/scikit-learn/scikit-learn/issues/28632 |
|
""" |
|
|
|
class A(BaseEstimator): |
|
def fit(self, X, y, sample_weight=None): |
|
return self |
|
|
|
error_message = re.escape( |
|
"set_fit_request() takes 0 positional argument but 1 were given" |
|
) |
|
|
|
|
|
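    # the bound method rejects positional arguments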
with pytest.raises(TypeError, match=error_message): |
|
A().set_fit_request(True) |
|
|
|
|
|
|
|
|
|
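    # re-assigning the attribute through the class leaves an unbound method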
A.set_fit_request = A.set_fit_request |
|
|
|
|
|
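    # the unbound method still works with keyword arguments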
A().set_fit_request(sample_weight=True) |
|
|
|
|
|
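    # and still rejects positional arguments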
with pytest.raises(TypeError, match=error_message): |
|
A().set_fit_request(True) |
|
|