Spaces:
Runtime error
Runtime error
import numpy as np | |
import pandas as pd | |
from pandas.core.dtypes.common import is_list_like | |
from sklearn.utils import check_consistent_length | |
from sklearn.utils.validation import column_or_1d | |
def check_inputs(X, y, sample_weight=None, ensure_2d=True): | |
"""Input validation for debiasing algorithms. | |
Checks all inputs for consistent length, validates shapes (optional for X), | |
and returns an array of all ones if sample_weight is ``None``. | |
Args: | |
X (array-like): Input data. | |
y (array-like, shape = (n_samples,)): Target values. | |
sample_weight (array-like, optional): Sample weights. | |
ensure_2d (bool, optional): Whether to raise a ValueError if X is not | |
2D. | |
Returns: | |
tuple: | |
* **X** (`array-like`) -- Validated X. Unchanged. | |
* **y** (`array-like`) -- Validated y. Possibly converted to 1D if | |
not a :class:`pandas.Series`. | |
* **sample_weight** (`array-like`) -- Validated sample_weight. If no | |
sample_weight is provided, returns a consistent-length array of | |
ones. | |
""" | |
if ensure_2d and X.ndim != 2: | |
raise ValueError("Expected X to be 2D, got ndim == {} instead.".format( | |
X.ndim)) | |
if not isinstance(y, pd.Series): # don't cast Series -> ndarray | |
y = column_or_1d(y) | |
if sample_weight is not None: | |
sample_weight = column_or_1d(sample_weight) | |
else: | |
sample_weight = np.ones(X.shape[0]) | |
check_consistent_length(X, y, sample_weight) | |
return X, y, sample_weight | |
def check_groups(arr, prot_attr, ensure_binary=False): | |
"""Get groups from the index of arr. | |
If there are multiple protected attributes provided, the index is flattened | |
to be a 1-D Index of tuples. If ensure_binary is ``True``, raises a | |
ValueError if there are not exactly two unique groups. Also checks that all | |
provided protected attributes are in the index. | |
Args: | |
arr (array-like): Either a Pandas object containing protected attribute | |
information in the index or array-like with explicit protected | |
attribute array(s) for `prot_attr`. | |
prot_attr (label or array-like or list of labels/arrays): Protected | |
attribute(s). If contains labels, arr must include these in its | |
index. If ``None``, all protected attributes in ``arr.index`` are | |
used. Can also be 1D array-like of the same length as arr or a | |
list of a combination of such arrays and labels in which case, arr | |
may not necessarily be a Pandas type. | |
ensure_binary (bool): Raise an error if the resultant groups are not | |
binary. | |
Returns: | |
tuple: | |
* **groups** (:class:`pandas.Index`) -- Label (or tuple of labels) | |
of protected attribute for each sample in arr. | |
* **prot_attr** (`FrozenList`) -- Modified input. If input is a | |
single label, returns single-item list. If input is ``None`` | |
returns list of all protected attributes. | |
""" | |
arr_is_pandas = isinstance(arr, (pd.DataFrame, pd.Series)) | |
if prot_attr is None: # use all protected attributes provided in arr | |
if not arr_is_pandas: | |
raise TypeError("Expected `Series` or `DataFrame` for arr, got " | |
f"{type(arr).__name__} instead. Otherwise, pass " | |
"explicit prot_attr array(s).") | |
groups = arr.index | |
elif arr_is_pandas: | |
df = arr.index.to_frame() | |
groups = df.set_index(prot_attr).index # let pandas handle errors | |
else: # arr isn't pandas. might be okay if prot_attr is array-like | |
df = pd.DataFrame(index=[None]*len(arr)) # dummy to check lengths match | |
try: | |
groups = df.set_index(prot_attr).index | |
except KeyError as e: | |
raise TypeError("arr does not include protected attributes in the " | |
"index. Check if this got dropped or prot_attr is " | |
"formatted incorrectly.") from e | |
prot_attr = groups.names | |
groups = groups.to_flat_index() | |
n_unique = groups.nunique() | |
if ensure_binary and n_unique != 2: | |
raise ValueError("Expected 2 protected attribute groups, got " | |
f"{groups.unique() if n_unique > 5 else n_unique}") | |
return groups, prot_attr | |