# Uploaded by erasmopurif in commit d2a8669 ("First commit").
import numpy as np
import pandas as pd
from pandas.core.dtypes.common import is_list_like
from sklearn.utils import check_consistent_length
from sklearn.utils.validation import column_or_1d
def check_inputs(X, y, sample_weight=None, ensure_2d=True):
    """Input validation for debiasing algorithms.

    Checks all inputs for consistent length, validates shapes (optional for X),
    and returns an array of all ones if sample_weight is ``None``.

    Args:
        X (array-like): Input data. Objects exposing ``ndim``/``shape`` (e.g.
            NumPy arrays, pandas objects, SciPy sparse matrices) are inspected
            through those attributes; plain sequences such as nested lists are
            also accepted.
        y (array-like, shape = (n_samples,)): Target values.
        sample_weight (array-like, optional): Sample weights.
        ensure_2d (bool, optional): Whether to raise a ValueError if X is not
            2D.

    Returns:
        tuple:

            * **X** (`array-like`) -- Validated X. Unchanged.
            * **y** (`array-like`) -- Validated y. Possibly converted to 1D if
              not a :class:`pandas.Series`.
            * **sample_weight** (`array-like`) -- Validated sample_weight. If no
              sample_weight is provided, returns a consistent-length array of
              ones.

    Raises:
        ValueError: If ``ensure_2d`` is ``True`` and X is not 2D. Errors from
            scikit-learn's ``column_or_1d`` / ``check_consistent_length``
            (e.g. inconsistent lengths) propagate unchanged.
    """
    # Prefer the object's own ``ndim`` (correct for sparse matrices, where
    # ``np.ndim`` would see a 0-d object array) and fall back to NumPy for
    # plain sequences like lists, which have no ``ndim`` attribute.
    ndim = X.ndim if hasattr(X, 'ndim') else np.ndim(X)
    if ensure_2d and ndim != 2:
        raise ValueError("Expected X to be 2D, got ndim == {} instead.".format(
                ndim))
    if not isinstance(y, pd.Series):  # don't cast Series -> ndarray
        y = column_or_1d(y)
    if sample_weight is not None:
        sample_weight = column_or_1d(sample_weight)
    else:
        # Same attribute-first strategy for the sample count: ``len`` is
        # ambiguous for sparse matrices, while lists lack ``shape``.
        n_samples = X.shape[0] if hasattr(X, 'shape') else len(X)
        sample_weight = np.ones(n_samples)
    check_consistent_length(X, y, sample_weight)
    return X, y, sample_weight
def check_groups(arr, prot_attr, ensure_binary=False):
    """Get groups from the index of arr.

    If there are multiple protected attributes provided, the index is flattened
    to be a 1-D Index of tuples. If ensure_binary is ``True``, raises a
    ValueError if there are not exactly two unique groups. Also checks that all
    provided protected attributes are in the index.

    Args:
        arr (array-like): Either a Pandas object containing protected attribute
            information in the index or array-like with explicit protected
            attribute array(s) for `prot_attr`.
        prot_attr (label or array-like or list of labels/arrays): Protected
            attribute(s). If contains labels, arr must include these in its
            index. If ``None``, all protected attributes in ``arr.index`` are
            used. Can also be 1D array-like of the same length as arr or a
            list of a combination of such arrays and labels in which case, arr
            may not necessarily be a Pandas type.
        ensure_binary (bool): Raise an error if the resultant groups are not
            binary.

    Returns:
        tuple:

            * **groups** (:class:`pandas.Index`) -- Label (or tuple of labels)
              of protected attribute for each sample in arr.
            * **prot_attr** (`FrozenList`) -- Modified input. If input is a
              single label, returns single-item list. If input is ``None``
              returns list of all protected attributes.

    Raises:
        TypeError: If ``prot_attr`` is ``None`` and arr is not a pandas
            Series/DataFrame, or if arr is not a pandas object and
            ``prot_attr`` refers to labels (which cannot be looked up).
        ValueError: If ``ensure_binary`` is ``True`` and the number of unique
            groups is not exactly 2.
    """
    arr_is_pandas = isinstance(arr, (pd.DataFrame, pd.Series))
    if prot_attr is None:  # use all protected attributes provided in arr
        if not arr_is_pandas:
            raise TypeError("Expected `Series` or `DataFrame` for arr, got "
                            f"{type(arr).__name__} instead. Otherwise, pass "
                            "explicit prot_attr array(s).")
        groups = arr.index
    elif arr_is_pandas:
        df = arr.index.to_frame()
        groups = df.set_index(prot_attr).index  # let pandas handle errors
    else:  # arr isn't pandas. might be okay if prot_attr is array-like
        # Dummy frame with no columns: set_index then only accepts explicit
        # arrays, and pandas verifies their length matches len(arr).
        df = pd.DataFrame(index=[None]*len(arr))
        try:
            groups = df.set_index(prot_attr).index
        except KeyError as e:
            raise TypeError("arr does not include protected attributes in the "
                            "index. Check if this got dropped or prot_attr is "
                            "formatted incorrectly.") from e
    prot_attr = groups.names
    groups = groups.to_flat_index()

    n_unique = groups.nunique()
    if ensure_binary and n_unique != 2:
        # List the offending groups when there are only a few; beyond that,
        # report just the count so the message stays readable.
        got = list(groups.unique()) if n_unique <= 5 else n_unique
        raise ValueError("Expected 2 protected attribute groups, got "
                         f"{got}")
    return groups, prot_attr