Passed
Push — master ( 170db5...8af2aa )
by Shlomi
02:43 queued 58s
created

get_fnr_indices()   A

Complexity

Conditions 2

Size

Total Lines 10
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 8
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
"""
2
Post-processing fairness intervention by choosing thresholds.
3
4
There are multiple definitions for choosing the thresholds:
5
6
1. Single threshold for all the sensitive attribute values
7
   that minimizes cost.
8
2. A threshold for each sensitive attribute value
9
   that minimizes cost.
10
3. A threshold for each sensitive attribute value
11
   that achieves independence and minimizes cost.
12
4. A threshold for each sensitive attribute value
13
   that achieves equal FNR (equal opportunity) and minimizes cost.
14
5. A threshold for each sensitive attribute value
15
   that achieves separation (equalized odds) and minimizes cost.
16
17
The code is based on `fairmlbook repository <https://github.com/fairmlbook/fairmlbook.github.io>`_.
18
19
References:
20
    - Hardt, M., Price, E., & Srebro, N. (2016).
21
      `Equality of opportunity in supervised learning
22
      <https://arxiv.org/abs/1610.02413>`_
23
      In Advances in neural information processing systems
24
      (pp. 3315-3323).
25
    - `Attacking discrimination with
26
      smarter machine learning by Google
27
      <https://research.google.com/bigpicture/attacking-discrimination-in-ml/>`_.
28
29
"""
30
31
# pylint: disable=no-name-in-module
32
33
import matplotlib.pylab as plt
34
import numpy as np
35
import pandas as pd
36
from scipy.spatial import Delaunay
37
38
from ethically.fairness.metrics.visualization import plot_roc_curves
39
40
41
def _ternary_search_float(f, left, right, tol):
42
    """Trinary search: minimize f(x) over [left, right], to within +/-tol in x.
43
44
    Works assuming f is quasiconvex.
45
46
    """
47
    while right - left > tol:
48
        left_third = (2 * left + right) / 3
49
        right_third = (left + 2 * right) / 3
50
        if f(left_third) < f(right_third):
51
            right = right_third
52
        else:
53
            left = left_third
54
    return (right + left) / 2
55
56
57
def _ternary_search_domain(f, domain):
58
    """Trinary search: minimize f(x) over a domain (sequence).
59
60
    Works assuming f is quasiconvex and domain is ascending sorted.
61
62
    """
63
    left = 0
64
    right = len(domain) - 1
65
    changed = True
66
67
    while changed and left != right:
68
69
        changed = False
70
71
        left_third = (2 * left + right) // 3
72
        right_third = (left + 2 * right) // 3
73
74
        if f(domain[left_third]) < f(domain[right_third]):
75
            right = right_third - 1
76
            changed = True
77
        else:
78
            left = left_third + 1
79
            changed = True
80
81
    return domain[(left + right) // 2]
82
83
84
def _cost_function(fpr, tpr, base_rate, cost_matrix):
85
    """Compute the cost of given (fpr, tpr).
86
87
    [[tn, fp], [fn, tp]]
88
    """
89
90
    fp = fpr * (1 - base_rate)
91
    tn = (1 - base_rate) - fp
92
    tp = tpr * base_rate
93
    fn = base_rate - tp
94
95
    conf_matrix = np.array([tn, fp, fn, tp])
96
97
    return (conf_matrix * np.array(cost_matrix).ravel()).sum()
98
99
100
def _extract_threshold(roc_curves):
101
    return next(iter(roc_curves.values()))[2]
102
103
104
def _first_index_above(array, value):
105
    """Find the smallest index i for which array[i] > value.
106
107
    If no such value exists, return len(array).
108
    """
109
    array = np.array(array)
110
    v = np.concatenate([array > value, np.ones_like(array[-1:])])
111
    return np.argmax(v, axis=0)
112
113
114
def _calc_acceptance_rate(fpr, tpr, base_rate):
115
    return 1 - ((fpr * (1 - base_rate)
116
                 + tpr * base_rate))
117
118
119
def find_single_threshold(roc_curves, base_rates, proportions,
                          cost_matrix):
    """Compute a single threshold, shared by all groups, that minimizes cost.

    The objective searched over threshold indices is the negated,
    proportion-weighted sum of the per-group costs; that same negated
    value is what is returned as the cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Threshold, FPR and TPR by attribute and cost value.
    :rtype: tuple

    """

    def negated_weighted_cost(index):
        accumulated = 0

        for group, roc in roc_curves.items():
            cost_of_group = _cost_function(roc[0][index], roc[1][index],
                                           base_rates[group], cost_matrix)
            accumulated += cost_of_group * proportions[group]

        return -accumulated

    thresholds = _extract_threshold(roc_curves)
    candidate_indices = range(len(thresholds))

    best_index = _ternary_search_domain(negated_weighted_cost,
                                        candidate_indices)

    cost = negated_weighted_cost(best_index)

    fpr_tpr = {}
    for group, roc in roc_curves.items():
        fpr_tpr[group] = (roc[0][best_index], roc[1][best_index])

    return thresholds[best_index], fpr_tpr, cost
163
164
165
def find_min_cost_thresholds(roc_curves, base_rates, cost_matrix):
    """Compute a separate cost-minimizing threshold for every group.

    Each group is searched independently over the threshold indices.
    The returned cost is the plain sum of the groups' negated costs;
    it is not weighted by group proportions.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
    :rtype: tuple

    """
    # pylint: disable=cell-var-from-loop

    cutoffs, fpr_tpr = {}, {}
    cost = 0

    thresholds = _extract_threshold(roc_curves)
    candidate_indices = range(len(thresholds))

    for group, roc in roc_curves.items():

        # Only called within this iteration, so closing over the loop
        # variables is safe.
        def negated_group_cost(index):
            return -_cost_function(roc[0][index], roc[1][index],
                                   base_rates[group], cost_matrix)

        best_index = _ternary_search_domain(negated_group_cost,
                                            candidate_indices)

        cutoffs[group] = thresholds[best_index]
        fpr_tpr[group] = (roc[0][best_index], roc[1][best_index])
        cost += negated_group_cost(best_index)

    return cutoffs, fpr_tpr, cost
205
206
207
def get_acceptance_rate_indices(roc_curves, base_rates,
                                acceptance_rate_value):
    """Find, per group, the ROC index for a given acceptance rate value.

    For each group the output of :func:`_calc_acceptance_rate` over
    the whole curve is scanned for the first entry greater than
    ``1 - acceptance_rate_value``; the index two positions before it
    is kept.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param acceptance_rate_value: Target acceptance rate.
    :return: ROC index by attribute.
    :rtype: dict

    """
    # NOTE(review): the offset here is 2 while get_fnr_indices uses 1;
    # confirm that the difference is intended.
    # The per-group rates do not depend on acceptance_rate_value and
    # could be computed once outside this function.
    return {
        group: _first_index_above(
            _calc_acceptance_rate(fpr=roc[0],
                                  tpr=roc[1],
                                  base_rate=base_rates[group]),
            1 - acceptance_rate_value) - 2
        for group, roc in roc_curves.items()
    }
222
223
224 View Code Duplication
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A common acceptance rate is searched over [0, 1] (tolerance 1e-3).
    For each candidate rate the per-group ROC indices are obtained
    from :func:`get_acceptance_rate_indices` and the objective is the
    negated, proportion-weighted sum of the per-group costs.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
             The cost value is the objective evaluated at the
             selected acceptance rate.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        # todo: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)

            total_cost += group_cost * proportions[group]

        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # The third returned item is documented as the cost value (and is
    # read as such by plot_costs), so evaluate the objective at the
    # chosen rate instead of handing back the rate itself.
    cost = total_cost_function(acceptance_rate_min_cost)

    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost
282
283
284
def get_fnr_indices(roc_curves, fnr_value):
    """Find, per group, the ROC index for a given FNR value.

    For each group, ``1 - tpr`` over the whole curve is scanned for
    the first entry greater than ``1 - fnr_value``; the index just
    before it is kept.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param fnr_value: Value the per-group curves are matched against.
    :return: ROC index by attribute.
    :rtype: dict

    """
    return {
        group: _first_index_above(1 - roc[1], 1 - fnr_value) - 1
        for group, roc in roc_curves.items()
    }
294
295
296 View Code Duplication
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A common FNR value is searched over [0, 1] (tolerance 1e-3). For
    each candidate the per-group ROC indices are obtained from
    :func:`get_fnr_indices` and the objective is the negated,
    proportion-weighted sum of the per-group costs.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value
             and the selected FNR value.
    :rtype: tuple

    """

    def total_cost_function(fnr_value):
        # todo: move demo here + multiple cost
        group_indices = get_fnr_indices(roc_curves, fnr_value)

        accumulated = 0

        for group, roc in roc_curves.items():
            position = group_indices[group]

            cost_of_group = _cost_function(roc[0][position],
                                           roc[1][position],
                                           base_rates[group],
                                           cost_matrix)
            cost_of_group *= proportions[group]

            accumulated += cost_of_group

        return -accumulated

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {}
    for group, roc in roc_curves.items():
        position = threshold_indices[group]
        fpr_tpr[group] = (roc[0][position], roc[1][position])

    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[position]
               for group, position in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
355
356
357
def _find_feasible_roc(roc_curves):
358
    polygons = [Delaunay(list(zip(fprs, tprs)))
359
                for group, (fprs, tprs, _) in roc_curves.items()]
360
361
    feasible_points = []
362
363
    for poly in polygons:
364
        for p in poly.points:
365
366
            if all(poly2.find_simplex(p) != -1 for poly2 in polygons):
367
                feasible_points.append(p)
368
369
    return np.array(feasible_points)
370
371
372
def find_separation_thresholds(roc_curves, base_rate, cost_matrix):
    """Compute thresholds that achieve separation and minimize cost.

    Also known as **equalized odds**.

    The candidate operating points are the ROC points feasible for
    every group (see :func:`_find_feasible_roc`). The point with the
    largest :func:`_cost_function` value is selected, which matches
    the other criteria in this module: they minimize the negated cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
             The thresholds dict is empty and the single FPR-TPR
             pair is stored under the key ``''``.
    :rtype: tuple
    :raises ValueError: If there is no feasible ROC point.

    """

    feasible_points = _find_feasible_roc(roc_curves)

    best_value, (best_fpr, best_tpr) = max(
        (_cost_function(fpr, tpr, base_rate, cost_matrix), (fpr, tpr))
        for fpr, tpr in feasible_points)

    # The other finders (e.g. find_single_threshold) report the
    # negated total as their cost value; use the same sign here so
    # the values are comparable, for instance in plot_costs.
    cost = -best_value

    return {}, {'': (best_fpr, best_tpr)}, cost
397
398
399
def find_thresholds(roc_curves, proportions, base_rate,
                    base_rates, cost_matrix,
                    with_single=True, with_min_cost=True,
                    with_independence=True, with_fnr=True,
                    with_separation=True):
    """Compute thresholds that achieve various criteria and minimize cost.

    Every enabled criterion is computed by its dedicated ``find_*``
    function and stored under its own key, in the order ``'single'``,
    ``'min_cost'``, ``'independence'``, ``'fnr'``, ``'separation'``.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
    :rtype: dict

    """
    # (enabled flag, result key, deferred computation); the lambdas
    # keep disabled criteria from being evaluated at all.
    criteria = (
        (with_single, 'single',
         lambda: find_single_threshold(roc_curves, base_rates,
                                       proportions, cost_matrix)),
        (with_min_cost, 'min_cost',
         lambda: find_min_cost_thresholds(roc_curves, base_rates,
                                          cost_matrix)),
        (with_independence, 'independence',
         lambda: find_independence_thresholds(roc_curves, base_rates,
                                              proportions, cost_matrix)),
        (with_fnr, 'fnr',
         lambda: find_fnr_thresholds(roc_curves, base_rates,
                                     proportions, cost_matrix)),
        (with_separation, 'separation',
         lambda: find_separation_thresholds(roc_curves, base_rate,
                                            cost_matrix)),
    )

    return {key: compute()
            for enabled, key, compute in criteria
            if enabled}
467
468
469
def plot_roc_curves_thresholds(roc_curves, thresholds_data,
                               aucs=None,
                               title='ROC Curves by Attribute',
                               ax=None, figsize=None,
                               title_fontsize='large',
                               text_fontsize='medium'):
    """Generate the ROC curves by attribute with thresholds.

    Based on :func:`skplt.metrics.plot_roc`

    The ROC curves are drawn by ``plot_roc_curves``; the FPR-TPR
    points of each threshold criterion are then scattered on the same
    axes, one marker shape per criterion. Only five marker shapes are
    defined, so criteria beyond the fifth are not drawn.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param aucs: Area Under the ROC (AUC) by attribute.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    ax = plot_roc_curves(roc_curves, aucs,
                         title, ax, figsize, title_fontsize, text_fontsize)

    MARKERS = ['o', '^', 'x', '+', 'p']

    for (name, data), marker in zip(thresholds_data.items(), MARKERS):
        label = name.replace('_', ' ').title()
        # data[1] maps attribute -> (fpr, tpr); transpose the pairs
        # into the x (fpr) and y (tpr) sequences that scatter expects.
        ax.scatter(*zip(*data[1].values()),
                   marker=marker, color='k', label=label,
                   zorder=float('inf'))

    # Build the legend on the axes that were drawn on; a pyplot-level
    # legend would go to the current axes, which need not be `ax`.
    ax.legend()

    return ax
519
520
521
def plot_fpt_tpr(roc_curves,
                 title='FPR-TPR Curves by Attribute',
                 ax=None, figsize=None,
                 title_fontsize='large', text_fontsize='medium'):
    """Generate FPR and TPR curves by thresholds and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    For every attribute value the FPR is drawn as a solid line and
    the TPR as a dashed line of the same color, both against the
    thresholds.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    thresholds = _extract_threshold(roc_curves)

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for position, (group, roc) in enumerate(roc_curves.items()):
        # Wrap around the color list so that no group is dropped when
        # there are more groups than colors in the cycle.
        color = colors[position % len(colors)]

        # Draw on `ax` itself, not on pyplot's current axes, so that
        # a caller-supplied axes object is honoured.
        ax.plot(thresholds, roc[0], '-',
                label='{} - FPR'.format(group), color=color)
        ax.plot(thresholds, roc[1], '--',
                label='{} - TPR'.format(group), color=color)

    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Probability', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(fontsize=text_fontsize)

    return ax
572
573
574
def plot_costs(thresholds_data,
               title='Cost by Threshold',
               ax=None, figsize=None,
               title_fontsize='large', text_fontsize='medium'):
    """Plot cost by threshold definition as a horizontal bar chart.

    Based on :func:`skplt.metrics.plot_roc`

    The cost of each criterion is the third item of its entry in
    ``thresholds_data``; any further items are ignored. Bars are
    sorted by cost in descending order.

    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    costs = {}
    for criterion, result in thresholds_data.items():
        _, _, cost, *_ = result
        costs[criterion.replace('_', ' ').title()] = cost

    ordered_costs = pd.Series(costs).sort_values(ascending=False)
    ordered_costs.plot(kind='barh', ax=ax)

    ax.set_xlabel('Cost', fontsize=text_fontsize)
    ax.set_ylabel('Threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
620