# FairUP/src/aif360/sklearn/datasets/tempeh_datasets.py
# Commit d2a8669 ("First commit"), erasmopurif
import pandas as pd
# tempeh is an optional dependency (extra "LawSchoolGPA"); this module must
# remain importable without it, so the failure is only reported, not raised.
try:
    import tempeh.configurations as tc
except ImportError as error:
    # Log through a module-level logger instead of ``logging.warning``: the
    # root-level helper calls ``logging.basicConfig()`` when the root logger
    # has no handlers, which would reconfigure the importing application's
    # logging as a side effect of merely importing this module.
    import logging
    logging.getLogger(__name__).warning(
        "{}: fetch_lawschool_gpa will be unavailable. To install, run:\n"
        "pip install 'aif360[LawSchoolGPA]'".format(error))
from aif360.sklearn.datasets.utils import standardize_dataset
def fetch_lawschool_gpa(subset="all", *, usecols=None, dropcols=None,
                        numeric_only=False, dropna=True):
    """Load the Law School GPA dataset.

    Note:
        By default, the data is downloaded from tempeh. See
        https://github.com/microsoft/tempeh for details.

    Args:
        subset ({'train', 'test', or 'all'}, optional): Select the dataset to
            load: 'train' for the training set, 'test' for the test set, 'all'
            for both.
        usecols (single label or list-like, optional): Feature column(s) to
            keep. All others are dropped.
        dropcols (single label or list-like, optional): Feature column(s) to
            drop.
        numeric_only (bool): Drop all non-numeric feature columns.
        dropna (bool): Drop rows with NAs. FIXME: NAs already dropped by
            tempeh. Note that 'race' values other than 'black'/'white' are
            turned into NaN below, so this may still remove rows.

    Returns:
        namedtuple: Tuple containing X, y, and sample_weights for the Law School
        GPA dataset accessible by index or name.

    Raises:
        ValueError: If ``subset`` is not 'train', 'test', or 'all'.
        ImportError: If the optional ``tempeh`` package is not installed.
    """
    if subset not in {'train', 'test', 'all'}:
        raise ValueError("subset must be either 'train', 'test', or 'all'; "
                         "cannot be {}".format(subset))

    # Import lazily so that a missing optional dependency surfaces as an
    # actionable ImportError here, rather than a NameError on the
    # module-level ``tc`` (left unbound when the top-level import fails).
    # On the success path this is just a ``sys.modules`` lookup.
    try:
        import tempeh.configurations as tc
    except ImportError as error:
        raise ImportError(
            "{}: fetch_lawschool_gpa requires tempeh. To install, run:\n"
            "pip install 'aif360[LawSchoolGPA]'".format(error)) from error

    dataset = tc.datasets["lawschool_gpa"]()
    X_train, X_test = dataset.get_X(format=pd.DataFrame)
    y_train, y_test = dataset.get_y(format=pd.Series)
    A_train, A_test = dataset.get_sensitive_features(name='race',
                                                     format=pd.Series)

    # Place features, target and protected attribute side by side
    # (column-wise concat aligns rows on the index) so each split is one frame.
    all_train = pd.concat([X_train, y_train, A_train], axis=1)
    all_test = pd.concat([X_test, y_test, A_test], axis=1)

    if subset == "train":
        df = all_train
    elif subset == "test":
        df = all_test
    else:
        df = pd.concat([all_train, all_test], axis=0)

    # Ordered categorical with 'black' < 'white'; any other value becomes NaN.
    # Item assignment (not attribute assignment) is the reliable way to
    # replace a column.
    df['race'] = df['race'].astype('category').cat.set_categories(
        ['black', 'white'], ordered=True)

    return standardize_dataset(df, prot_attr='race', target='zfygpa',
                               usecols=usecols, dropcols=dropcols,
                               numeric_only=numeric_only, dropna=dropna)