Spaces:
Runtime error
Runtime error
File size: 2,066 Bytes
d2a8669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import numpy as np
from aif360.datasets import BinaryLabelDataset
from aif360.algorithms import Transformer
class ARTClassifier(Transformer):
    """Wraps an instance of an :obj:`art.classifiers.Classifier` to extend
    :obj:`~aif360.algorithms.Transformer`.
    """

    def __init__(self, art_classifier):
        """Initialize ARTClassifier.

        Args:
            art_classifier (art.classifier.Classifier): A Classifier
                object from the `adversarial-robustness-toolbox`_.

        .. _adversarial-robustness-toolbox:
           https://github.com/Trusted-AI/adversarial-robustness-toolbox
        """
        super(ARTClassifier, self).__init__(art_classifier=art_classifier)
        self._art_classifier = art_classifier

    def fit(self, dataset, batch_size=128, nb_epochs=20):
        """Train a classifier on the input.

        Args:
            dataset (Dataset): Training dataset.
            batch_size (int): Size of batches (passed through to ART).
            nb_epochs (int): Number of epochs to use for training (passed
                through to ART).

        Returns:
            ARTClassifier: Returns self.
        """
        # NOTE(review): ``dataset.labels`` is forwarded unchanged; many ART
        # classifiers expect one-hot encoded labels here -- confirm the
        # wrapped classifier accepts this layout.
        self._art_classifier.fit(dataset.features, dataset.labels,
                                 batch_size=batch_size, nb_epochs=nb_epochs)
        return self

    def predict(self, dataset, logits=False):
        """Perform prediction for the input.

        Args:
            dataset (Dataset): Test dataset.
            logits (bool, optional): True if prediction should be done at the
                logits layer (passed through to ART).

        Returns:
            Dataset: Dataset with predicted labels in the `labels` field.
        """
        # Only the features are needed for inference. The ground-truth labels
        # must not be forwarded: depending on the ART version, the second
        # positional parameter of ``predict`` is ``logits`` or ``batch_size``,
        # so passing ``dataset.labels`` there either raises ``TypeError``
        # (multiple values for ``logits``) or is silently misinterpreted.
        pred_labels = self._art_classifier.predict(dataset.features,
                                                   logits=logits)

        if isinstance(dataset, BinaryLabelDataset):
            # Collapse per-class scores to the index of the highest-scoring
            # class, reshaped into a single column.
            pred_labels = np.argmax(pred_labels, axis=1).reshape((-1, 1))

        pred_dataset = dataset.copy()
        pred_dataset.labels = pred_labels

        return pred_dataset
|