FairUP / src /aif360 /algorithms /transformer.py
erasmopurif's picture
First commit
d2a8669
from abc import abstractmethod
from functools import wraps
from aif360.datasets import Dataset
from aif360.decorating_metaclass import ApplyDecorator
# TODO: Use sklearn.exceptions.NotFittedError instead?
class NotFittedError(ValueError, AttributeError):
"""Error to be raised if `predict` or `transform` is called before `fit`."""
def addmetadata(func):
"""Decorator for instance methods which perform a transformation and return
a new dataset.
Automatically populates the `metadata` field of the new dataset to reflect
details of the transformation that occurred, e.g.::
{
'transformer': 'TransformerClass.function_name',
'params': kwargs_from_init,
'previous': [all_datasets_used_by_func]
}
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
new_dataset = func(self, *args, **kwargs)
if isinstance(new_dataset, Dataset):
new_dataset.metadata = new_dataset.metadata.copy()
new_dataset.metadata.update({
'transformer': '{}.{}'.format(type(self).__name__, func.__name__),
'params': self._params,
'previous': [a for a in args if isinstance(a, Dataset)]
})
return new_dataset
return wrapper
BaseClass = ApplyDecorator(addmetadata)
class Transformer(BaseClass):
"""Abstract base class for transformers.
Transformers are an abstraction for any process which acts on a
:obj:`Dataset` and returns a new, modified Dataset. This definition
encompasses pre-processing, in-processing, and post-processing algorithms.
"""
@abstractmethod
def __init__(self, **kwargs):
"""Initialize a Transformer object.
Algorithm-specific configuration parameters should be passed here.
"""
self._params = kwargs
def fit(self, dataset):
"""Train a model on the input.
Args:
dataset (Dataset): Input dataset.
Returns:
Transformer: Returns self.
"""
return self
def predict(self, dataset):
"""Return a new dataset with labels predicted by running this
Transformer on the input.
Args:
dataset (Dataset): Input dataset.
Returns:
Dataset: Output dataset. `metadata` should reflect the details of
this transformation.
"""
raise NotImplementedError("'predict' is not supported for this class. "
"Perhaps you meant 'transform' or 'fit_predict' instead?")
def transform(self, dataset):
"""Return a new dataset generated by running this Transformer on the
input.
This function could return different `dataset.features`,
`dataset.labels`, or both.
Args:
dataset (Dataset): Input dataset.
Returns:
Dataset: Output dataset. `metadata` should reflect the details of
this transformation.
"""
raise NotImplementedError("'transform' is not supported for this class."
" Perhaps you meant 'predict' or 'fit_transform' instead?")
def fit_predict(self, dataset):
"""Train a model on the input and predict the labels.
Equivalent to calling `fit(dataset)` followed by `predict(dataset)`.
Args:
dataset (Dataset): Input dataset.
Returns:
Dataset: Output dataset. `metadata` should reflect the details of
this transformation.
"""
return self.fit(dataset).predict(dataset)
def fit_transform(self, dataset):
"""Train a model on the input and transform the dataset accordingly.
Equivalent to calling `fit(dataset)` followed by `transform(dataset)`.
Args:
dataset (Dataset): Input dataset.
Returns:
Dataset: Output dataset. `metadata` should reflect the details of
this transformation.
"""
return self.fit(dataset).transform(dataset)