File size: 12,449 Bytes
d2a8669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright 2019 Seth V. Neel, Michael J. Kearns, Aaron L. Roth, Zhiwei Steven Wu
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.
"""Class GerryFairClassifier implementing the 'FairFictPlay' Algorithm of [KRNW18].

This module contains functionality to instantiate, fit, and predict
using the FairFictPlay algorithm of:
https://arxiv.org/abs/1711.05144
It also contains the ability to audit arbitrary classifiers for
rich subgroup unfairness, where rich subgroups are defined by hyperplanes
over the sensitive attributes. This iteration of the codebase supports hyperplanes, trees,
kernel methods, and support vector machines. For usage examples refer to examples/gerry_plots.ipynb
"""


import copy
from aif360.algorithms.inprocessing.gerryfair import heatmap
from aif360.algorithms.inprocessing.gerryfair.clean import array_to_tuple
from aif360.algorithms.inprocessing.gerryfair.learner import Learner
from aif360.algorithms.inprocessing.gerryfair.auditor import *
from aif360.algorithms.inprocessing.gerryfair.classifier_history import ClassifierHistory
from aif360.algorithms import Transformer


class GerryFairClassifier(Transformer):
    """Model is an algorithm for learning classifiers that are fair with respect
    to rich subgroups.

    Rich subgroups are defined by (linear) functions over the sensitive
    attributes, and fairness notions are statistical: false positive, false
    negative, and statistical parity rates. This implementation uses a max of
    two regressions as a cost-sensitive classification oracle, and supports
    linear regression, support vector machines, decision trees, and kernel
    regression. For details see:

    References:
        .. [1] "Preventing Fairness Gerrymandering: Auditing and Learning for
           Subgroup Fairness." Michale Kearns, Seth Neel, Aaron Roth, Steven Wu.
           ICML '18.
        .. [2] "An Empirical Study of Rich Subgroup Fairness for Machine
           Learning". Michael Kearns, Seth Neel, Aaron Roth, Steven Wu. FAT '19.
    """
    def __init__(self, C=10, printflag=False, heatmapflag=False,
                 heatmap_iter=10, heatmap_path='.', max_iters=10, gamma=0.01,
                 fairness_def='FP', predictor=linear_model.LinearRegression()):
        """Initialize Model Object and set hyperparameters.

        Args:
            C: Maximum L1 Norm for the Dual Variables (hyperparameter)
            printflag: Print Output Flag
            heatmapflag: Save Heatmaps every heatmap_iter Flag
            heatmap_iter: Save Heatmaps every heatmap_iter
            heatmap_path: Save Heatmaps path
            max_iters: Time Horizon for the fictitious play dynamic.
            gamma: Fairness Approximation Paramater
            fairness_def: Fairness notion, FP, FN, SP.
            errors: see fit()
            fairness_violations: see fit()
            predictor: Hypothesis class for the Learner. Supports LR, SVM, KR,
                Trees.
        """

        super(GerryFairClassifier, self).__init__()
        self.C = C
        self.printflag = printflag
        self.heatmapflag = heatmapflag
        self.heatmap_iter = heatmap_iter
        self.heatmap_path = heatmap_path
        self.max_iters = max_iters
        self.gamma = gamma
        self.fairness_def = fairness_def
        self.predictor = predictor
        self.classifiers = None
        self.errors = None
        self.fairness_violations = None
        if self.fairness_def not in ['FP', 'FN']:
            raise Exception(
                'This metric is not yet supported for learning. Metric specified: {}.'
                .format(self.fairness_def))

    def fit(self, dataset, early_termination=True):
        """Run Fictitious play to compute the approximately fair classifier.

        Args:
            dataset: dataset object with its own class definition in datasets
                folder inherits from class StandardDataset.
            early_termination: Terminate Early if Auditor can't find fairness
                violation of more than gamma.
        Returns:
            Self
        """

        # defining variables and data structures for algorithm
        X, X_prime, y = clean.extract_df_from_ds(dataset)
        learner = Learner(X, y, self.predictor)
        auditor = Auditor(dataset, self.fairness_def)
        history = ClassifierHistory()

        # initialize variables
        n = X.shape[0]
        costs_0, costs_1, X_0 = auditor.initialize_costs(n)
        metric_baseline = 0
        predictions = [0.0] * n

        # scaling variables for heatmap
        vmin = None
        vmax = None

        # print output variables
        errors = []
        fairness_violations = []

        iteration = 1
        while iteration < self.max_iters:
            # learner's best response: solve the CSC problem, get mixture decisions on X to update prediction probabilities
            history.append_classifier(learner.best_response(costs_0, costs_1))
            error, predictions = learner.generate_predictions(
                history.get_most_recent(), predictions, iteration)
            # auditor's best response: find group, update costs
            metric_baseline = auditor.get_baseline(y, predictions)
            group = auditor.get_group(predictions, metric_baseline)
            costs_0, costs_1 = auditor.update_costs(costs_0, costs_1, group,
                                                    self.C, iteration,
                                                    self.gamma)

            # outputs
            errors.append(error)
            fairness_violations.append(group.weighted_disparity)
            self.print_outputs(iteration, error, group)
            vmin, vmax = self.save_heatmap(
                iteration, dataset,
                history.get_most_recent().predict(X), vmin, vmax)
            iteration += 1

            # early termination:
            if early_termination and (len(errors) >= 5) and (
                    (errors[-1] == errors[-2]) or fairness_violations[-1] == fairness_violations[-2]) and \
                    fairness_violations[-1] < self.gamma:
                iteration = self.max_iters

        self.classifiers = history.classifiers
        self.errors = errors
        self.fairness_violations = fairness_violations
        return self

    def predict(self, dataset, threshold=.5):
        """Return dataset object where labels are the predictions returned by
        the fitted model.

        Args:
            dataset: dataset object with its own class definition in datasets
                folder inherits from class StandardDataset.
            threshold: The positive prediction cutoff for the soft-classifier.

        Returns:
            dataset_new: modified dataset object where the labels attribute are
            the predictions returned by the self model
        """

        # Generates predictions.
        dataset_new = copy.deepcopy(dataset)
        data, _, _ = clean.extract_df_from_ds(dataset_new)
        num_classifiers = len(self.classifiers)
        y_hat = None
        for hyp in self.classifiers:
            new_predictions = hyp.predict(data)/num_classifiers
            if y_hat is None:
                y_hat = new_predictions
            else:
                y_hat = np.add(y_hat, new_predictions)
        if threshold:
            dataset_new.labels = np.asarray(
                [1 if y >= threshold else 0 for y in y_hat])
        else:
            dataset_new.labels = np.asarray([y for y in y_hat])
        # ensure ndarray is formatted correctly
        dataset_new.labels.resize(dataset.labels.shape, refcheck=True)
        return dataset_new

    def print_outputs(self, iteration, error, group):
        """Helper function to print outputs at each iteration of fit.

        Args:
            iteration: current iter
            error: most recent error
            group: most recent group found by the auditor
        """

        if self.printflag:
            print(
                'iteration: {}, error: {}, fairness violation: {}, violated group size: {}'
                .format(int(iteration), error, group.weighted_disparity,
                        group.group_size))

    def save_heatmap(self, iteration, dataset, predictions, vmin, vmax):
        """Helper Function to save the heatmap.

        Args:
            iteration: current iteration
            dataset: dataset object with its own class definition in datasets
                folder inherits from class StandardDataset.
            predictions: predictions of the model self on dataset.
            vmin: see documentation of heatmap.py heat_map function
            vmax: see documentation of heatmap.py heat_map function

        Returns:
            (vmin, vmax)
        """

        X, X_prime, y = clean.extract_df_from_ds(dataset)
        # save heatmap every heatmap_iter iterations or the last iteration
        if (self.heatmapflag and (iteration % self.heatmap_iter) == 0):
            # initial heat map
            X_prime_heat = X_prime.iloc[:, 0:2]
            eta = 0.1
            minmax = heatmap.heat_map(
                X, X_prime_heat, y, predictions, eta,
                self.heatmap_path + '/heatmap_iteration_{}'.format(iteration),
                vmin, vmax)
            if iteration == 1:
                vmin = minmax[0]
                vmax = minmax[1]
        return vmin, vmax

    def generate_heatmap(self,
                         dataset,
                         predictions,
                         vmin=None,
                         vmax=None,
                         cols_index=[0, 1],
                         eta=.1):
        """Helper Function to generate the heatmap at the current time.

        Args:
            iteration:current iteration
            dataset: dataset object with its own class definition in datasets
                folder inherits from class StandardDataset.
            predictions: predictions of the model self on dataset.
            vmin: see documentation of heatmap.py heat_map function
            vmax: see documentation of heatmap.py heat_map function
        """

        X, X_prime, y = clean.extract_df_from_ds(dataset)
        # save heatmap every heatmap_iter iterations or the last iteration
        X_prime_heat = X_prime.iloc[:, cols_index]
        minmax = heatmap.heat_map(X, X_prime_heat, y, predictions, eta,
                                  self.heatmap_path, vmin, vmax)

    def pareto(self, dataset, gamma_list):
        """Assumes Model has FP specified for metric. Trains for each value of
        gamma, returns error, FP (via training), and FN (via auditing) values.

        Args:
            dataset: dataset object with its own class definition in datasets
                folder inherits from class StandardDataset.
            gamma_list: the list of gamma values to generate the pareto curve

        Returns:
            list of errors, list of fp violations of those models, list of fn
            violations of those models
        """

        C = self.C
        max_iters = self.max_iters

        # Store errors and fp over time for each gamma

        # change var names, but no real dependence on FP logic
        all_errors = []
        all_fp_violations = []
        all_fn_violations = []
        self.C = C
        self.max_iters = max_iters

        auditor = Auditor(dataset, 'FN')
        for g in gamma_list:
            self.gamma = g
            fitted_model = self.fit(dataset, early_termination=True)
            errors, fairness_violations = fitted_model.errors, fitted_model.fairness_violations
            predictions = array_to_tuple((self.predict(dataset)).labels)
            _, fn_violation = auditor.audit(predictions)
            all_errors.append(errors[-1])
            all_fp_violations.append(fairness_violations[-1])
            all_fn_violations.append(fn_violation)

        return all_errors, all_fp_violations, all_fn_violations