#' Classification Metric
#' @description
#' Class for computing metrics based on two BinaryLabelDatasets. The first dataset is the original (ground-truth) one and the second is the output of the classification transformer (or similar).
#' @param dataset (BinaryLabelDataset) Dataset containing ground-truth labels.
#' @param classified_dataset (BinaryLabelDataset) Dataset containing predictions.
#' @param unprivileged_groups Unprivileged groups. A list containing the name of the unprivileged protected attribute and its unprivileged value.
#' @param privileged_groups Privileged groups. A list containing the name of the privileged protected attribute and its privileged value.
#' @usage
#' classification_metric(dataset, classified_dataset, unprivileged_groups, privileged_groups)
#' @examples
#' \dontrun{
#' load_aif360_lib()
#' # Input dataset
#' data <- data.frame("feat" = c(0,0,1,1,1,1,0,1,1,0), "label" = c(1,0,0,1,0,0,1,0,1,1))
#' # Create aif compatible input dataset
#' act <- aif360::binary_label_dataset(data_path = data, favor_label=0, unfavor_label=1,
#'                                     unprivileged_protected_attribute=0,
#'                                     privileged_protected_attribute=1,
#'                                     target_column="label", protected_attribute="feat")
#' # Classified dataset
#' pred_data <- data.frame("feat" = c(0,0,1,1,1,1,0,1,1,0), "label" = c(1,0,1,1,1,0,1,0,0,1))
#' # Create aif compatible classified dataset
#' pred <- aif360::binary_label_dataset(data_path = pred_data, favor_label=0, unfavor_label=1,
#'                                      unprivileged_protected_attribute=0,
#'                                      privileged_protected_attribute=1,
#'                                      target_column="label", protected_attribute="feat")
#' # Create an instance of classification metric
#' # (group values match the dataset definition: unprivileged = 0, privileged = 1)
#' cm <- classification_metric(act, pred,
#'                             unprivileged_groups = list('feat', 0),
#'                             privileged_groups = list('feat', 1))
#' # Access metric functions
#' cm$accuracy()
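#' # A few more of the metrics listed under "Available metrics" below,
#' # called on the same object
#' cm$disparate_impact()
#' cm$statistical_parity_difference()
#' cm$binary_confusion_matrix()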
#' }
#' @seealso
#' \href{https://aif360.readthedocs.io/en/latest/modules/metrics.html#classification-metric}{Explanations of the available classification metrics}
#'
#' Available metrics:
#' \itemize{
#'  \item accuracy
#'  \item average_abs_odds_difference
#'  \item average_odds_difference
#'  \item between_all_groups_coefficient_of_variation
#'  \item between_all_groups_generalized_entropy_index
#'  \item between_all_groups_theil_index
#'  \item between_group_coefficient_of_variation
#'  \item between_group_generalized_entropy_index
#'  \item between_group_theil_index
#'  \item binary_confusion_matrix
#'  \item coefficient_of_variation
#'  \item disparate_impact
#'  \item equal_opportunity_difference
#'  \item error_rate
#'  \item error_rate_difference
#'  \item error_rate_ratio
#'  \item false_discovery_rate
#'  \item false_discovery_rate_difference
#'  \item false_discovery_rate_ratio
#'  \item false_negative_rate
#'  \item false_negative_rate_difference
#'  \item false_negative_rate_ratio
#'  \item false_omission_rate
#'  \item false_omission_rate_difference
#'  \item false_omission_rate_ratio
#'  \item false_positive_rate
#'  \item false_positive_rate_difference
#'  \item false_positive_rate_ratio
#'  \item generalized_binary_confusion_matrix
#'  \item generalized_entropy_index
#'  \item generalized_false_negative_rate
#'  \item generalized_false_positive_rate
#'  \item generalized_true_negative_rate
#'  \item generalized_true_positive_rate
#'  \item negative_predictive_value
#'  \item num_false_negatives
#'  \item num_false_positives
#'  \item num_generalized_false_negatives
#'  \item num_generalized_false_positives
#'  \item num_generalized_true_negatives
#'  \item num_generalized_true_positives
#'  \item num_pred_negatives
#'  \item num_pred_positives
#'  \item num_true_negatives
#'  \item num_true_positives
#'  \item performance_measures
#'  \item positive_predictive_value
#'  \item power
#'  \item precision
#'  \item recall
#'  \item selection_rate
#'  \item sensitivity
#'  \item specificity
#'  \item statistical_parity_difference
#'  \item theil_index
#'  \item true_negative_rate
#'  \item true_positive_rate
#'  \item true_positive_rate_difference
#' }
#' @export
#' @importFrom reticulate py_suppress_warnings
#'
classification_metric <- function(dataset,
                                  classified_dataset,
                                  unprivileged_groups,
                                  privileged_groups){
  # Convert the R group lists into the Python dictionaries expected by aif360
  u_dict <- dict_fn(unprivileged_groups)
  p_dict <- dict_fn(privileged_groups)
  # Instantiate the Python aif360.metrics.ClassificationMetric via reticulate
  return(metrics$ClassificationMetric(dataset,
                                      classified_dataset,
                                      unprivileged_groups = u_dict,
                                      privileged_groups = p_dict))
} | |
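
# Usage sketch (kept as a comment so nothing runs at package load): the value
# returned by classification_metric() is a reticulate wrapper around the Python
# class aif360.metrics.ClassificationMetric, so each metric listed above is
# called as a method with `$`. Assuming `act` and `pred` are built as in the
# roxygen example:
#   cm <- classification_metric(act, pred,
#                               unprivileged_groups = list('feat', 0),
#                               privileged_groups = list('feat', 1))
#   cm$equal_opportunity_difference()
#   cm$average_odds_difference()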