Spaces:
Runtime error
Runtime error
File size: 2,273 Bytes
d2a8669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
import numpy as np
from aif360.datasets import StructuredDataset
'''
A multiclass label dataset supports multiple values for both the favorable and the unfavorable labels.
'''
class MulticlassLabelDataset(StructuredDataset):
    """Base class for all structured datasets with multiclass labels.

    Several distinct label values may be declared favorable and several may
    be declared unfavorable; both groups are given as lists of values.
    """

    def __init__(self, favorable_label=None, unfavorable_label=None, **kwargs):
        """
        Args:
            favorable_label (list): Label values which are considered
                favorable (i.e. "positive"). Defaults to ``[1.]``.
            unfavorable_label (list): Label values which are considered
                unfavorable (i.e. "negative"). Defaults to ``[0.]``.
            **kwargs: StructuredDataset arguments.
        """
        # Build a fresh list per instance: a mutable default in the signature
        # would be shared by every instance created without explicit labels,
        # and copying also decouples the dataset from the caller's own list.
        self.favorable_label = (
            [1.] if favorable_label is None else list(favorable_label))
        self.unfavorable_label = (
            [0.] if unfavorable_label is None else list(unfavorable_label))

        # The label lists are assigned before the parent initializer runs, as
        # in the original ordering.
        # NOTE(review): presumably the parent constructor triggers
        # validate_dataset(), which reads these attributes - confirm.
        super(MulticlassLabelDataset, self).__init__(**kwargs)

    def validate_dataset(self):
        """Error checking and type validation.

        When `scores` is merely a mirror of `labels`, it is first rewritten
        to 1.0 for favorable labels and 0.0 for every other label, then the
        parent class validation and the multiclass-specific checks are run.

        Raises:
            ValueError: `labels` must be shape [n, 1].
            ValueError: `favorable_label` and `unfavorable_label` must be the
                only values present in `labels`.
        """
        # fix scores before validating
        if np.all(self.scores == self.labels):
            # Vectorized membership test. A new array is bound instead of
            # writing in place, so `labels` cannot be clobbered even if both
            # attributes happen to reference the same array.
            self.scores = np.isin(
                self.scores, self.favorable_label).astype(np.float64)

        super(MulticlassLabelDataset, self).validate_dataset()

        # =========================== SHAPE CHECKING ===========================
        # Verify if the labels are only 1 column
        if self.labels.ndim != 2 or self.labels.shape[1] != 1:
            raise ValueError("MulticlassLabelDataset only supports single-column "
                "labels:\n\tlabels.shape = {}".format(self.labels.shape))

        # =========================== VALUE CHECKING ===========================
        # Check if the favorable and unfavorable labels match those in the dataset
        # A set union is used rather than `+` so the check does not depend on
        # both attributes being plain lists.
        known_labels = set(self.favorable_label) | set(self.unfavorable_label)
        if not set(self.labels.ravel()) <= known_labels:
            raise ValueError("The favorable and unfavorable labels provided do "
                             "not match the labels in the dataset.")
|