Spaces:
Runtime error
Runtime error
File size: 4,249 Bytes
d2a8669 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#' Reject option classification
#'
#' @description Reject option classification is a postprocessing technique
#' that gives favorable outcomes to unprivileged groups and unfavorable
#' outcomes to privileged groups in a confidence band around the decision
#' boundary with the highest uncertainty.
#' @param unprivileged_groups A list representation for the unprivileged
#'   group, e.g. `list("race", 0)`.
#' @param privileged_groups A list representation for the privileged group,
#'   e.g. `list("race", 1)`.
#' @param low_class_thresh Smallest classification threshold to use in the
#'   optimization. Should be between 0 and 1.
#' @param high_class_thresh Highest classification threshold to use in the
#'   optimization. Should be between 0 and 1.
#' @param num_class_thresh Number of classification thresholds between
#'   `low_class_thresh` and `high_class_thresh` for the optimization search.
#'   Must be a single positive whole number; whole-valued doubles such as
#'   `100` are accepted and converted to integer.
#' @param num_ROC_margin Number of relevant ROC margins to be used in the
#'   optimization search. Must be a single positive whole number;
#'   whole-valued doubles such as `50` are accepted and converted to integer.
#' @param metric_name Name of the metric to use for the optimization. Allowed
#'   options are "Statistical parity difference", "Average odds difference",
#'   "Equal opportunity difference".
#' @param metric_ub Upper bound of constraint on the metric value.
#' @param metric_lb Lower bound of constraint on the metric value.
#' @return The object returned by `post_algo$RejectOptionClassification()`;
#'   the example below calls its `fit()` and `predict()` methods.
#' @examples
#' \dontrun{
#' # Example with Adult Dataset
#' load_aif360_lib()
#' ad <- adult_dataset()
#' p <- list("race", 1)
#' u <- list("race", 0)
#'
#' col_names <- c(ad$feature_names, "label")
#' ad_df <- data.frame(ad$features, ad$labels)
#' colnames(ad_df) <- col_names
#'
#' lr <- glm(label ~ ., data = ad_df, family = binomial)
#'
#' # type = "response" gives probabilities; the default type ("link") gives
#' # log-odds, for which a 0.5 cut-off is not a probability threshold.
#' ad_prob <- predict(lr, ad_df, type = "response")
#' ad_pred <- factor(ifelse(ad_prob > 0.5, 1, 0))
#'
#' ad_df_pred <- data.frame(ad_df)
#' ad_df_pred$label <- as.character(ad_pred)
#' colnames(ad_df_pred) <- c(ad$feature_names, "label")
#'
#' ad_ds <- binary_label_dataset(ad_df, target_column = "label",
#'   favor_label = 1, unfavor_label = 0,
#'   unprivileged_protected_attribute = 0,
#'   privileged_protected_attribute = 1, protected_attribute = "race")
#'
#' ad_ds_pred <- binary_label_dataset(ad_df_pred, target_column = "label",
#'   favor_label = 1, unfavor_label = 0,
#'   unprivileged_protected_attribute = 0,
#'   privileged_protected_attribute = 1, protected_attribute = "race")
#'
#' roc <- reject_option_classification(unprivileged_groups = u,
#'   privileged_groups = p,
#'   low_class_thresh = 0.01,
#'   high_class_thresh = 0.99,
#'   num_class_thresh = 100L,
#'   num_ROC_margin = 50L,
#'   metric_name = "Statistical parity difference",
#'   metric_ub = 0.05,
#'   metric_lb = -0.05)
#'
#' roc <- roc$fit(ad_ds, ad_ds_pred)
#'
#' ds_transformed_pred <- roc$predict(ad_ds_pred)
#' }
#' @export
#'
reject_option_classification <- function(
    unprivileged_groups,
    privileged_groups,
    low_class_thresh = 0.01,
    high_class_thresh = 0.99,
    num_class_thresh = 100L,
    num_ROC_margin = 50L,
    metric_name = "Statistical parity difference",
    metric_ub = 0.05,
    metric_lb = -0.05) {
  # Validate a count argument and return it as an R integer.
  # NOTE(review): post_algo is presumably a reticulate handle to the Python
  # AIF360 module; reticulate converts R doubles to Python floats and R
  # integers to Python ints, so counts must be integer before crossing over.
  as_count <- function(x, arg) {
    ok <- is.numeric(x) && length(x) == 1L && is.finite(x) &&
      x == round(x) && x >= 1 && x <= .Machine$integer.max
    if (!ok) {
      stop("`", arg, "` must be a single positive whole number.",
           call. = FALSE)
    }
    as.integer(x)
  }
  num_class_thresh <- as_count(num_class_thresh, "num_class_thresh")
  num_ROC_margin <- as_count(num_ROC_margin, "num_ROC_margin")

  u_dict <- dict_fn(unprivileged_groups)
  p_dict <- dict_fn(privileged_groups)

  # Arguments are passed positionally, in the same order as the original
  # wrapper, to the underlying constructor.
  post_algo$RejectOptionClassification(u_dict,
                                       p_dict,
                                       low_class_thresh,
                                       high_class_thresh,
                                       num_class_thresh,
                                       num_ROC_margin,
                                       metric_name,
                                       metric_ub,
                                       metric_lb)
}
|