|
import inspect |
|
from collections import defaultdict |
|
from functools import partial |
|
|
|
import numpy as np |
|
from numpy.testing import assert_array_equal |
|
|
|
from sklearn.base import ( |
|
BaseEstimator, |
|
ClassifierMixin, |
|
MetaEstimatorMixin, |
|
RegressorMixin, |
|
TransformerMixin, |
|
clone, |
|
) |
|
from sklearn.metrics._scorer import _Scorer, mean_squared_error |
|
from sklearn.model_selection import BaseCrossValidator |
|
from sklearn.model_selection._split import GroupsConsumerMixin |
|
from sklearn.utils._metadata_requests import ( |
|
SIMPLE_METHODS, |
|
) |
|
from sklearn.utils.metadata_routing import ( |
|
MetadataRouter, |
|
MethodMapping, |
|
process_routing, |
|
) |
|
from sklearn.utils.multiclass import _check_partial_fit_first_call |
|
|
|
|
|
def record_metadata(obj, record_default=True, **kwargs):
    """Store the metadata passed to one of ``obj``'s methods.

    The record is stored on ``obj._records`` (created on first use) and is
    indexed first by the name of the function that called this helper (the
    "callee", e.g. ``fit``), then by the name of the function that called the
    callee (the "caller"). Each call appends the ``kwargs`` dict to the list
    stored under that pair.

    Because names are taken from ``inspect.stack()``, this helper must be
    called directly from the consuming method itself.

    If ``record_default`` is False, keyword arguments whose value is the
    string ``"default"`` are dropped, so arguments that were left at their
    placeholder default are not recorded.
    """
    frames = inspect.stack()
    # frames[0] is this helper, frames[1] the consuming method, frames[2] its caller
    callee_name, caller_name = frames[1].function, frames[2].function

    if not hasattr(obj, "_records"):
        obj._records = defaultdict(lambda: defaultdict(list))

    if not record_default:
        # The isinstance guard avoids comparing e.g. numpy arrays to a str.
        kwargs = {
            name: value
            for name, value in kwargs.items()
            if not (isinstance(value, str) and value == "default")
        }

    obj._records[callee_name][caller_name].append(kwargs)
|
|
|
|
|
def check_recorded_metadata(obj, method, parent, split_params=tuple(), **kwargs):
    """Assert that ``obj.<method>`` received the expected metadata from ``parent``.

    Parameters
    ----------
    obj : estimator object
        Sub-estimator whose recorded metadata is checked.

    method : str
        Name of the sub-estimator's method that metadata was routed to
        (the "callee" in metadata-routing terms).

    parent : str
        Name of the method that should have called ``method`` (the
        "caller" in metadata-routing terms).

    split_params : tuple, default=empty
        Names of parameters whose recorded value only needs to be a subset
        of the expected value (e.g. values sliced by a CV split).

    **kwargs : dict
        Expected metadata, by name.

    Notes
    -----
    If nothing was recorded for ``(method, parent)``, there is nothing to
    iterate over and the check passes without asserting anything.
    """
    records = getattr(obj, "_records", {}).get(method, {}).get(parent, [])

    for record in records:
        # The recorded metadata names must be exactly the expected names.
        assert set(kwargs) == set(
            record
        ), f"Expected {kwargs.keys()} vs {record.keys()}"

        for name, expected in kwargs.items():
            recorded = record[name]
            if name in split_params and recorded is not None:
                # Subset check for parameters that may have been split.
                assert np.isin(recorded, expected).all()
            elif isinstance(recorded, np.ndarray):
                assert_array_equal(recorded, expected)
            else:
                # Non-array metadata must be passed through as the same object.
                assert (
                    recorded is expected
                ), f"Expected {recorded} vs {expected}. Method: {method}"
|
|
|
|
|
# Variant of ``record_metadata`` that drops arguments left at the "default"
# placeholder. It is deliberately a ``functools.partial`` rather than a ``def``:
# in CPython ``partial`` is implemented in C and adds no Python frame, so
# ``record_metadata``'s ``inspect.stack()[1]`` / ``[2]`` lookups still resolve
# to the consuming method and its caller. A ``def`` wrapper would add a frame
# and record the wrapper's name instead.
record_metadata_not_default = partial(record_metadata, record_default=False)
|
|
|
|
|
def assert_request_is_empty(metadata_request, exclude=None):
    """Check if a metadata request dict is empty.

    One can exclude a method or a list of methods from the check using the
    ``exclude`` parameter. If metadata_request is a MetadataRouter, then
    ``exclude`` can be of the form ``{"object" : [method, ...]}``; each
    routed object is then checked recursively with its own exclusions.

    A request counts as non-empty when any of its values is not None
    (i.e. True, False, or a string alias).
    """
    if isinstance(metadata_request, MetadataRouter):
        for name, route_mapping in metadata_request:
            if exclude is not None and name in exclude:
                _exclude = exclude[name]
            else:
                _exclude = None
            assert_request_is_empty(route_mapping.router, exclude=_exclude)
        return

    if exclude is None:
        exclude = []
    elif isinstance(exclude, str):
        # A single method name must be matched exactly. Testing
        # ``method in "partial_fit"`` would also match "fit", and
        # ``method in "predict_proba"`` would also match "predict".
        exclude = [exclude]

    for method in SIMPLE_METHODS:
        if method in exclude:
            continue
        mmr = getattr(metadata_request, method)
        # Any str alias is also "not None", so a single test covers both.
        props = [prop for prop, alias in mmr.requests.items() if alias is not None]
        assert not props, f"Non-empty request for {method!r}: {props}"
|
|
|
|
|
def assert_request_equal(request, dictionary):
    """Assert that ``request`` holds exactly the requests in ``dictionary``.

    ``dictionary`` maps method names to their expected ``requests`` dict.
    Every method in ``SIMPLE_METHODS`` that is not a key of ``dictionary``
    must have no requests at all.
    """
    for method_name in dictionary:
        assert getattr(request, method_name).requests == dictionary[method_name]

    for method_name in SIMPLE_METHODS:
        if method_name in dictionary:
            continue
        assert not getattr(request, method_name).requests
|
|
|
|
|
class _Registry(list):
    """List that survives shallow and deep copies as the same object.

    Tests use it to get a reference to sub-estimators, which are not
    necessarily stored on the meta-estimator. Sub-estimators (and their
    parameters, including this registry) may be copied or cloned. Returning
    ``self`` from both copy hooks keeps every copy pointing at this single
    instance.
    """

    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        # ``memo`` is unused: no new object is ever created.
        return self
|
|
|
|
|
class ConsumingRegressor(RegressorMixin, BaseEstimator):
    """Regressor that records the metadata each method receives.

    Every method forwards ``sample_weight`` and ``metadata`` to
    ``record_metadata_not_default``, so arguments left at the ``"default"``
    placeholder are not recorded.

    Parameters
    ----------
    registry : list, default=None
        When not None, ``fit`` and ``partial_fit`` append the estimator to
        this list so a test can retrieve the instance later. Leave as None
        when no such reference is needed.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def _append_to_registry(self):
        """Append ``self`` to ``registry`` when one was given."""
        if self.registry is not None:
            self.registry.append(self)

    # ``record_metadata_not_default`` is called directly in each method below
    # because it identifies the callee from the calling frame.

    def partial_fit(self, X, y, sample_weight="default", metadata="default"):
        self._append_to_registry()
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        self._append_to_registry()
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return self

    def predict(self, X, y=None, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        # Constant zero predictions, one per sample.
        return np.zeros(len(X))

    def score(self, X, y, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return 1
|
|
|
|
|
class NonConsumingClassifier(ClassifierMixin, BaseEstimator):
    """A classifier which accepts no metadata on any method.

    Outputs are fixed: the first ``len(X) // 2`` rows are assigned class 0
    and the remaining rows class 1, whatever the data.

    Parameters
    ----------
    alpha : float, default=0.0
        Stored but not used by any method.
    """

    def __init__(self, alpha=0.0):
        self.alpha = alpha

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def partial_fit(self, X, y, classes=None):
        return self

    def decision_function(self, X):
        return self.predict(X)

    def predict(self, X):
        n_samples = len(X)
        half = n_samples // 2
        labels = np.empty(shape=(n_samples,))
        labels[:half] = 0
        labels[half:] = 1
        return labels

    def predict_proba(self, X):
        # Dummy probabilities matching ``predict``'s fixed split.
        n_samples = len(X)
        half = n_samples // 2
        proba = np.empty(shape=(n_samples, 2))
        proba[:half] = [1.0, 0.0]
        proba[half:] = [0.0, 1.0]
        return proba

    def predict_log_proba(self, X):
        # Dummy output: returns the plain probabilities, not their logarithm.
        return self.predict_proba(X)
|
|
|
|
|
class NonConsumingRegressor(RegressorMixin, BaseEstimator):
    """A regressor which accepts no metadata on any method.

    ``fit`` and ``partial_fit`` do nothing and return ``self``; ``predict``
    returns a constant prediction of 1 for each sample.
    """

    def fit(self, X, y):
        return self

    def partial_fit(self, X, y):
        return self

    def predict(self, X):
        return np.ones(len(X))
|
|
|
|
|
class ConsumingClassifier(ClassifierMixin, BaseEstimator):
    """Classifier that records the metadata each method receives.

    Every method forwards ``sample_weight`` and ``metadata`` to
    ``record_metadata_not_default``, so arguments left at the ``"default"``
    placeholder are not recorded. Predictions are fixed: the first
    ``len(X) // 2`` rows get 1 and the rest get 0.

    Parameters
    ----------
    registry : list, default=None
        When not None, ``fit`` and ``partial_fit`` append the estimator to
        this list so a test can retrieve the instance later. Leave as None
        when no such reference is needed.

    alpha : float, default=0
        This parameter is only used to test the ``*SearchCV`` objects, and
        doesn't do anything.
    """

    def __init__(self, registry=None, alpha=0.0):
        self.alpha = alpha
        self.registry = registry

    def _append_to_registry(self):
        """Append ``self`` to ``registry`` when one was given."""
        if self.registry is not None:
            self.registry.append(self)

    # ``record_metadata_not_default`` is called directly in each method below
    # because it identifies the callee from the calling frame.

    def partial_fit(
        self, X, y, classes=None, sample_weight="default", metadata="default"
    ):
        self._append_to_registry()
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        _check_partial_fit_first_call(self, classes)
        return self

    def fit(self, X, y, sample_weight="default", metadata="default"):
        self._append_to_registry()
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        self.classes_ = np.unique(y)
        self.coef_ = np.ones_like(X)
        return self

    def predict(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        half = len(X) // 2
        labels = np.empty(shape=(len(X),), dtype="int8")
        labels[:half] = 1
        labels[half:] = 0
        return labels

    def predict_proba(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        half = len(X) // 2
        proba = np.empty(shape=(len(X), 2))
        proba[:half] = [1.0, 0.0]
        proba[half:] = [0.0, 1.0]
        return proba

    def predict_log_proba(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return np.zeros(shape=(len(X), 2))

    def decision_function(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        half = len(X) // 2
        scores = np.empty(shape=(len(X),))
        scores[:half] = 1
        scores[half:] = 0
        return scores

    def score(self, X, y, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return 1
|
|
|
|
|
class ConsumingTransformer(TransformerMixin, BaseEstimator):
    """Transformer that records the metadata its methods receive.

    ``transform`` adds 1 to ``X`` and ``inverse_transform`` subtracts 1.

    Parameters
    ----------
    registry : list, default=None
        When not None, ``fit`` appends the estimator to this list so a test
        can retrieve the instance later. Leave as None when no such reference
        is needed.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        self.fitted_ = True
        return self

    def transform(self, X, sample_weight="default", metadata="default"):
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return X + 1

    def fit_transform(self, X, y, sample_weight="default", metadata="default"):
        # Overridden because ``TransformerMixin.fit_transform`` does not route
        # metadata to ``transform``. Here ``fit`` and ``transform`` both receive
        # ``sample_weight`` and ``metadata`` and record this method as caller.
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        fitted = self.fit(X, y, sample_weight=sample_weight, metadata=metadata)
        return fitted.transform(X, sample_weight=sample_weight, metadata=metadata)

    def inverse_transform(self, X, sample_weight=None, metadata=None):
        # NOTE(review): unlike the other methods the defaults are None, not
        # "default", so unpassed arguments are still recorded (as None).
        record_metadata_not_default(self, sample_weight=sample_weight, metadata=metadata)
        return X - 1
|
|
|
|
|
class ConsumingNoFitTransformTransformer(BaseEstimator):
    """Metadata-consuming transformer without a ``fit_transform`` method.

    It does not inherit from ``TransformerMixin`` and so does not get
    ``fit_transform``. Note that TransformerMixin's ``fit_transform`` doesn't
    route metadata to ``transform``. Both methods use ``record_metadata``,
    so arguments left at their None default are recorded too.

    Parameters
    ----------
    registry : list, default=None
        When not None, ``fit`` appends the estimator to this list.
    """

    def __init__(self, registry=None):
        self.registry = registry

    def fit(self, X, y=None, sample_weight=None, metadata=None):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata(self, sample_weight=sample_weight, metadata=metadata)
        return self

    def transform(self, X, sample_weight=None, metadata=None):
        record_metadata(self, sample_weight=sample_weight, metadata=metadata)
        # Identity transform.
        return X
|
|
|
|
|
class ConsumingScorer(_Scorer):
    """Scorer that records the metadata passed to ``_score``.

    Configured with ``mean_squared_error`` on ``predict`` output and
    ``sign=1``. Of the received metadata, only ``sample_weight`` is forwarded
    to the parent ``_score``.

    Parameters
    ----------
    registry : list, default=None
        When not None, each ``_score`` call appends the scorer to this list.
    """

    def __init__(self, registry=None):
        super().__init__(
            score_func=mean_squared_error,
            sign=1,
            kwargs={},
            response_method="predict",
        )
        self.registry = registry

    def _score(self, method_caller, clf, X, y, **kwargs):
        if self.registry is not None:
            self.registry.append(self)
        record_metadata_not_default(self, **kwargs)
        return super()._score(
            method_caller, clf, X, y, sample_weight=kwargs.get("sample_weight")
        )
|
|
|
|
|
class ConsumingSplitter(GroupsConsumerMixin, BaseCrossValidator):
    """Cross-validator that records the metadata passed to ``split``.

    Samples are cut into two contiguous halves at ``len(X) // 2``, giving two
    folds in which each half is used once as the test set.

    Parameters
    ----------
    registry : list, default=None
        When not None, ``split`` appends the splitter to this list (on first
        iteration of the generator).
    """

    def __init__(self, registry=None):
        self.registry = registry

    def split(self, X, y=None, groups="default", metadata="default"):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata_not_default(self, groups=groups, metadata=metadata)

        cut = len(X) // 2
        first_half = list(range(cut))
        second_half = list(range(cut, len(X)))
        # Fold 1: train on the second half, test on the first half.
        yield second_half, first_half
        # Fold 2: train on the first half, test on the second half.
        yield first_half, second_half

    def get_n_splits(self, X=None, y=None, groups=None, metadata=None):
        return 2

    def _iter_test_indices(self, X=None, y=None, groups=None):
        cut = len(X) // 2
        first_half = list(range(cut))
        second_half = list(range(cut, len(X)))
        yield second_half
        yield first_half
|
|
|
|
|
class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """A meta-regressor which is only a router.

    It consumes no metadata itself. ``fit`` uses ``process_routing`` to route
    metadata to the ``fit`` method of a clone of ``estimator``.

    Parameters
    ----------
    estimator : estimator object
        Sub-estimator. It is cloned and fitted in ``fit``, and the result is
        stored as ``estimator_``.
    """

    def __init__(self, estimator):
        self.estimator = estimator

    def fit(self, X, y, **fit_params):
        params = process_routing(self, "fit", **fit_params)
        self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit)
        # Estimator convention: ``fit`` returns the fitted estimator itself,
        # which allows chaining such as ``MetaRegressor(est).fit(X, y).estimator_``.
        return self

    def get_metadata_routing(self):
        router = MetadataRouter(owner=self.__class__.__name__).add(
            estimator=self.estimator,
            method_mapping=MethodMapping().add(caller="fit", callee="fit"),
        )
        return router
|
|
|
|
|
class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator):
    """Meta-regressor that both consumes ``sample_weight`` and routes metadata.

    ``fit`` records ``sample_weight`` on itself, then routes metadata
    (including ``sample_weight``) to a clone of ``estimator``. ``predict``
    routes metadata to the fitted clone's ``predict``.

    Parameters
    ----------
    estimator : estimator object
        Sub-estimator, cloned and fitted into ``estimator_``.

    registry : list, default=None
        When not None, ``fit`` appends this meta-estimator to the list.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **fit_params):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        routed = process_routing(self, "fit", sample_weight=sample_weight, **fit_params)
        sub_estimator = clone(self.estimator)
        self.estimator_ = sub_estimator.fit(X, y, **routed.estimator.fit)
        return self

    def predict(self, X, **predict_params):
        routed = process_routing(self, "predict", **predict_params)
        return self.estimator_.predict(X, **routed.estimator.predict)

    def get_metadata_routing(self):
        mapping = (
            MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="predict", callee="predict")
        )
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.estimator, method_mapping=mapping)
        )
|
|
|
|
|
class WeightedMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    """A meta-estimator which also consumes sample_weight itself in ``fit``.

    ``fit`` records ``sample_weight`` on itself, then routes metadata
    (including ``sample_weight``) to the ``fit`` of a clone of ``estimator``.

    Parameters
    ----------
    estimator : estimator object
        Sub-estimator, cloned and fitted into ``estimator_``.

    registry : list, default=None
        When not None, ``fit`` appends this meta-estimator to the list.
    """

    def __init__(self, estimator, registry=None):
        self.estimator = estimator
        self.registry = registry

    def fit(self, X, y, sample_weight=None, **kwargs):
        if self.registry is not None:
            self.registry.append(self)

        record_metadata(self, sample_weight=sample_weight)
        routed = process_routing(self, "fit", sample_weight=sample_weight, **kwargs)
        sub_estimator = clone(self.estimator)
        self.estimator_ = sub_estimator.fit(X, y, **routed.estimator.fit)
        return self

    def get_metadata_routing(self):
        fit_only = MethodMapping().add(caller="fit", callee="fit")
        return (
            MetadataRouter(owner=self.__class__.__name__)
            .add_self_request(self)
            .add(estimator=self.estimator, method_mapping=fit_only)
        )
|
|
|
|
|
class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator):
    """A simple meta-transformer that routes metadata to its sub-transformer.

    Parameters
    ----------
    transformer : transformer object
        Cloned and fitted in ``fit``. The fitted copy is stored as
        ``transformer_`` and used by ``transform``.
    """

    def __init__(self, transformer):
        self.transformer = transformer

    def fit(self, X, y=None, **fit_params):
        routed = process_routing(self, "fit", **fit_params)
        fresh = clone(self.transformer)
        self.transformer_ = fresh.fit(X, y, **routed.transformer.fit)
        return self

    def transform(self, X, y=None, **transform_params):
        # ``y`` is accepted but not used.
        routed = process_routing(self, "transform", **transform_params)
        return self.transformer_.transform(X, **routed.transformer.transform)

    def get_metadata_routing(self):
        mapping = (
            MethodMapping()
            .add(caller="fit", callee="fit")
            .add(caller="transform", callee="transform")
        )
        return MetadataRouter(owner=self.__class__.__name__).add(
            transformer=self.transformer, method_mapping=mapping
        )
|
|