erasmopurif's picture
First commit
d2a8669
import numpy as np
from aif360.datasets import BinaryLabelDataset
from aif360.algorithms import Transformer
class ARTClassifier(Transformer):
    """Wraps an instance of an :obj:`art.classifiers.Classifier` to extend
    :obj:`~aif360.algorithms.Transformer`.
    """

    def __init__(self, art_classifier):
        """Initialize ARTClassifier.

        Args:
            art_classifier (art.classifier.Classifier): A Classifier
                object from the `adversarial-robustness-toolbox`_.

        .. _adversarial-robustness-toolbox:
           https://github.com/Trusted-AI/adversarial-robustness-toolbox
        """
        super(ARTClassifier, self).__init__(art_classifier=art_classifier)
        self._art_classifier = art_classifier

    def fit(self, dataset, batch_size=128, nb_epochs=20):
        """Train a classifier on the input.

        Args:
            dataset (Dataset): Training dataset.
            batch_size (int): Size of batches (passed through to ART).
            nb_epochs (int): Number of epochs to use for training (passed
                through to ART).

        Returns:
            ARTClassifier: Returns self.
        """
        # NOTE(review): `dataset.labels` is forwarded unchanged. Confirm that
        # the wrapped ART classifier accepts this label encoding, since some
        # ART classifiers expect one-hot encoded targets.
        self._art_classifier.fit(dataset.features, dataset.labels,
                                 batch_size=batch_size, nb_epochs=nb_epochs)
        return self

    def predict(self, dataset, logits=False):
        """Perform prediction for the input.

        Args:
            dataset (Dataset): Test dataset.
            logits (bool, optional): True if prediction should be done at the
                logits layer (passed through to ART).

        Returns:
            Dataset: Copy of ``dataset`` with predicted labels in the
            `labels` field.
        """
        # ART's `predict` takes only the inputs. Passing the ground-truth
        # labels positionally would bind them to the next parameter
        # (`logits` or `batch_size`, depending on the ART version).
        pred_labels = self._art_classifier.predict(dataset.features,
                                                   logits=logits)

        if isinstance(dataset, BinaryLabelDataset):
            # Collapse per-class scores to the index of the highest-scoring
            # class, as a column vector.
            # NOTE(review): these are class indices (0, 1, ...). They are not
            # mapped back to the dataset's favorable/unfavorable label values.
            pred_labels = np.argmax(pred_labels, axis=1).reshape((-1, 1))

        # Return a copy so the caller's dataset keeps its original labels.
        pred_dataset = dataset.copy()
        pred_dataset.labels = pred_labels
        return pred_dataset