from abc import abstractmethod
from functools import wraps

from aif360.datasets import Dataset
from aif360.decorating_metaclass import ApplyDecorator


# TODO: Use sklearn.exceptions.NotFittedError instead?
class NotFittedError(ValueError, AttributeError):
    """Error to be raised if `predict` or `transform` is called before `fit`.

    Inherits from both :obj:`ValueError` and :obj:`AttributeError`, so callers
    catching either of those will also catch this error.
    """


def addmetadata(func):
    """Decorator for instance methods which perform a transformation and return
    a new dataset.

    Automatically populates the `metadata` field of the new dataset to reflect
    details of the transformation that occurred, e.g.::

        {
            'transformer': 'TransformerClass.function_name',
            'params': kwargs_from_init,
            'previous': [all_datasets_used_by_func]
        }

    Datasets are collected from both positional and keyword arguments of the
    call. If the wrapped method returns anything other than a
    :obj:`Dataset`, the result is passed through untouched.

    Args:
        func (callable): Instance method to wrap.

    Returns:
        callable: Wrapper with the same signature as `func`.
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        new_dataset = func(self, *args, **kwargs)
        if isinstance(new_dataset, Dataset):
            # Look at keyword arguments as well as positional ones; otherwise
            # a call such as `transform(dataset=ds)` would record no inputs.
            inputs = args + tuple(kwargs.values())
            # Rebind to a (shallow) copy before updating so that the update
            # does not write into a metadata object which may still be
            # referenced from elsewhere (e.g. the dataset it was derived from).
            new_dataset.metadata = new_dataset.metadata.copy()
            new_dataset.metadata.update({
                'transformer': '{}.{}'.format(type(self).__name__,
                                              func.__name__),
                'params': self._params,
                'previous': [a for a in inputs if isinstance(a, Dataset)]
            })
        return new_dataset
    return wrapper


# NOTE(review): presumably `ApplyDecorator` builds a base class which applies
# `addmetadata` to the methods of its subclasses -- confirm against
# `aif360.decorating_metaclass`.
BaseClass = ApplyDecorator(addmetadata)


class Transformer(BaseClass):
    """Abstract base class for transformers.

    Transformers are an abstraction for any process which acts on a
    :obj:`Dataset` and returns a new, modified Dataset. This definition
    encompasses pre-processing, in-processing, and post-processing algorithms.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        """Initialize a Transformer object.

        Algorithm-specific configuration parameters should be passed here.
        They are stored as-is in `self._params`, which `addmetadata` reports
        under the `'params'` key of the output dataset's `metadata`.

        Args:
            **kwargs: Configuration parameters of the concrete transformer.
        """
        self._params = kwargs

    def fit(self, dataset):
        """Train a model on the input.

        The base implementation does nothing and simply returns `self`.

        Args:
            dataset (Dataset): Input dataset.

        Returns:
            Transformer: Returns self.
        """
        return self

    def predict(self, dataset):
        """Return a new dataset with labels predicted by running this
        Transformer on the input.

        Args:
            dataset (Dataset): Input dataset.

        Returns:
            Dataset: Output dataset. `metadata` should reflect the details of
            this transformation.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("'predict' is not supported for this class. "
            "Perhaps you meant 'transform' or 'fit_predict' instead?")

    def transform(self, dataset):
        """Return a new dataset generated by running this Transformer on the
        input.

        This function could return different `dataset.features`,
        `dataset.labels`, or both.

        Args:
            dataset (Dataset): Input dataset.

        Returns:
            Dataset: Output dataset. `metadata` should reflect the details of
            this transformation.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError("'transform' is not supported for this class."
            " Perhaps you meant 'predict' or 'fit_transform' instead?")

    def fit_predict(self, dataset):
        """Train a model on the input and predict the labels.

        Equivalent to calling `fit(dataset)` followed by `predict(dataset)`.

        Args:
            dataset (Dataset): Input dataset.

        Returns:
            Dataset: Output dataset. `metadata` should reflect the details of
            this transformation.
        """
        return self.fit(dataset).predict(dataset)

    def fit_transform(self, dataset):
        """Train a model on the input and transform the dataset accordingly.

        Equivalent to calling `fit(dataset)` followed by `transform(dataset)`.

        Args:
            dataset (Dataset): Input dataset.

        Returns:
            Dataset: Output dataset. `metadata` should reflect the details of
            this transformation.
        """
        return self.fit(dataset).transform(dataset)