Passed
Pull Request — master (#14)
by Shlomi
04:42 queued 02:29
created

plot_roc_curves_thresholds()   A

Complexity

Conditions 2

Size

Total Lines 50
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 16
nop 8
dl 0
loc 50
rs 9.6
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
"""
2
Post-processing fairness intervention by choosing thresholds.
3
4
There are multiple definitions for choosing the thresholds:
5
6
1. Single threshold for all the sensitive attribute values
7
   that minimizes cost.
8
2. A threshold for each sensitive attribute value
9
   that minimizes cost.
10
3. A threshold for each sensitive attribute value
11
   that achieves independence and minimizes cost.
12
4. A threshold for each sensitive attribute value
13
   that achieves equal FNR (equal opportunity) and minimizes cost.
14
5. A threshold for each sensitive attribute value
15
   that achieves separation (equalized odds) and minimizes cost.
16
17
The code is based on `fairmlbook repository <https://github.com/fairmlbook/fairmlbook.github.io>`_.
18
19
References:
20
    - Hardt, M., Price, E., & Srebro, N. (2016).
21
      Equality of opportunity in supervised learning.
22
      In Advances in neural information processing systems
23
      (pp. 3315-3323).
24
    - `Attacking discrimination with
25
      smarter machine learning by Google
26
      <https://research.google.com/bigpicture/attacking-discrimination-in-ml/>`_.
27
28
"""
29
30
# pylint: disable=no-name-in-module
31
32
import matplotlib.pylab as plt
33
import numpy as np
34
import pandas as pd
35
from scipy.spatial import Delaunay
36
37
from ethically.fairness.metrics.visualization import plot_roc_curves
38
39
40
def _ternary_search_float(f, left, right, tol):
    """Ternary search: minimize f(x) over [left, right], to within +/-tol in x.

    Works assuming f is quasiconvex.

    :param f: Function of a single number to minimize.
    :param left: Lower bound of the search interval.
    :param right: Upper bound of the search interval.
    :param tol: The search stops once ``right - left <= tol``.
                Must be positive.
    :return: Midpoint of the final interval.
    :raises ValueError: If ``tol`` is not positive.

    """
    if tol <= 0:
        # With a non-positive tolerance the interval can stop shrinking
        # (floating point) while the loop condition stays true.
        raise ValueError('tol must be positive')

    while right - left > tol:
        left_third = (2 * left + right) / 3
        right_third = (left + 2 * right) / 3

        # Keep the two thirds of the interval that must hold the minimum.
        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third

    return (right + left) / 2
54
55
56
def _ternary_search_domain(f, domain):
    """Ternary search: minimize f(x) over a domain (sequence).

    Works assuming f is quasiconvex and domain is ascending sorted.

    :param f: Function of a single domain element to minimize.
    :param domain: Indexable sequence of candidates (supports ``len``
                   and integer indexing).
    :return: The element of ``domain`` at which the search ends.
    :raises ValueError: If ``domain`` is empty.

    """
    if len(domain) == 0:
        raise ValueError('domain must not be empty')

    left = 0
    right = len(domain) - 1

    # Only narrow while the two probes are distinct positions
    # (right - left >= 2 guarantees left_third < right_third).
    # Narrowing on coinciding probes would discard a candidate
    # without ever comparing it.
    while right - left > 1:
        left_third = (2 * left + right) // 3
        right_third = (left + 2 * right) // 3

        if f(domain[left_third]) < f(domain[right_third]):
            right = right_third - 1
        else:
            left = left_third + 1

    # At most two candidates remain; compare them directly.
    return min((domain[index] for index in range(left, right + 1)), key=f)
81
82
83
def _cost_function(fpr, tpr, base_rate, cost_matrix):
    """Compute the cost of given (fpr, tpr).

    The rates are converted to the four confusion-matrix fractions,
    which are multiplied element-wise by the flattened cost matrix
    and summed.

    :param fpr: False positive rate.
    :param tpr: True positive rate.
    :param base_rate: Fraction of positives.
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :return: Sum of the confusion-matrix fractions weighted by
             ``cost_matrix``.

    """
    negatives = 1 - base_rate

    fp = fpr * negatives
    tn = negatives - fp
    tp = tpr * base_rate
    fn = base_rate - tp

    # Same order as the flattened [[tn, fp], [fn, tp]] cost matrix.
    conf_matrix = np.array([tn, fp, fn, tp])
    weights = np.array(cost_matrix).ravel()

    return (conf_matrix * weights).sum()
97
98
99
def _extract_threshold(roc_curves):
    """Return the thresholds (third item) of the first ROC curve.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; each value is indexable with the
                       thresholds at position 2.
    :type roc_curves: dict
    :return: The thresholds of the first curve in iteration order.
    :raises ValueError: If ``roc_curves`` is empty.

    """
    # NOTE(review): callers use this single array for every group, so
    # presumably all curves share the same thresholds - confirm.
    try:
        first_curve = next(iter(roc_curves.values()))
    except StopIteration:
        raise ValueError('roc_curves must not be empty') from None

    return first_curve[2]
101
102
103
def _first_index_above(array, value):
    """Find the smallest index i for which array[i] > value.

    If no such value exists, return len(array).
    An empty array therefore yields 0.

    :param array: Sequence of numbers (converted with ``np.asarray``).
    :param value: Value to compare against.
    :return: Index along the first axis, as returned by ``np.argmax``.

    """
    array = np.asarray(array)

    # A trailing all-True sentinel makes argmax land on len(array)
    # when nothing exceeds value. It is built from the shape, not from
    # array[-1:], so it also exists for empty input and is always bool.
    sentinel = np.ones((1,) + array.shape[1:], dtype=bool)
    is_above = np.concatenate([array > value, sentinel])

    return np.argmax(is_above, axis=0)
111
112
113
def _calc_acceptance_rate(fpr, tpr, base_rate):
    """Return one minus the fraction predicted positive.

    The fraction predicted positive is
    ``fpr * (1 - base_rate) + tpr * base_rate``; note that the value
    returned is its complement.

    :param fpr: False positive rate(s).
    :param tpr: True positive rate(s).
    :param base_rate: Fraction of positives.
    :return: ``1 - (fpr * (1 - base_rate) + tpr * base_rate)``.

    """
    predicted_positive = fpr * (1 - base_rate) + tpr * base_rate
    return 1 - predicted_positive
116
117
118
def find_single_threshold(roc_curves, base_rates, proportions,
                          cost_matrix):
    """Compute single threshold that minimizes cost.

    The same cutoff index is applied to the ROC curve of every group.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Threshold, FPR and TPR by attribute and cost value.
             The cost value is the search objective at the chosen
             index, i.e. the negated proportion-weighted sum of the
             group costs.
    :rtype: tuple

    """

    def total_cost_function(index):
        # Negated because the ternary search minimizes its objective.
        total_cost = 0

        for group, roc in roc_curves.items():
            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group], cost_matrix)
            total_cost += group_cost * proportions[group]

        return -total_cost

    # The thresholds of the first curve are used for all groups.
    thresholds = _extract_threshold(roc_curves)

    cutoff_index = _ternary_search_domain(total_cost_function,
                                          range(len(thresholds)))

    fpr_tpr = {group: (roc[0][cutoff_index], roc[1][cutoff_index])
               for group, roc in roc_curves.items()}

    cost = total_cost_function(cutoff_index)

    return thresholds[cutoff_index], fpr_tpr, cost
162
163
164
def find_min_cost_thresholds(roc_curves, base_rates, cost_matrix):
    """Compute thresholds by attribute values that minimize cost.

    Each group is searched independently for its own cutoff index.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
             The cost value is the plain (unweighted) sum over groups
             of each group's search objective, i.e. of its negated
             cost.
    :rtype: tuple

    """
    cutoffs = {}
    fpr_tpr = {}
    cost = 0

    # The thresholds of the first curve are used for all groups.
    thresholds = _extract_threshold(roc_curves)
    candidate_indices = range(len(thresholds))

    for group, roc in roc_curves.items():

        # Loop values are bound as defaults so the closure does not
        # depend on the loop variables after this iteration.
        def group_cost_function(index, fprs=roc[0], tprs=roc[1],
                                base_rate=base_rates[group]):
            # Negated because the ternary search minimizes.
            return -_cost_function(fprs[index], tprs[index],
                                   base_rate, cost_matrix)

        threshold_index = _ternary_search_domain(group_cost_function,
                                                 candidate_indices)

        cutoffs[group] = thresholds[threshold_index]
        fpr_tpr[group] = (roc[0][threshold_index],
                          roc[1][threshold_index])

        cost += group_cost_function(threshold_index)

    return cutoffs, fpr_tpr, cost
204
205
206
def get_acceptance_rate_indices(roc_curves, base_rates,
                                acceptance_rate_value):
    """Find, per attribute, the ROC index for a given acceptance rate.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param acceptance_rate_value: Target acceptance rate.
    :type acceptance_rate_value: float
    :return: Index into the ROC arrays by attribute.
    :rtype: dict

    """
    indices = {}

    for group, roc in roc_curves.items():
        # can be calculated outside the function
        acceptance_rates = _calc_acceptance_rate(fpr=roc[0],
                                                 tpr=roc[1],
                                                 base_rate=base_rates[group])

        # _calc_acceptance_rate yields the complement of the rate,
        # hence the comparison against 1 - acceptance_rate_value.
        # NOTE(review): the offset of 2 can produce -1 or -2, which
        # wrap to the end of the ROC arrays when used as indices -
        # confirm whether clamping to 0 is intended.
        first_above = _first_index_above(acceptance_rates,
                                         1 - acceptance_rate_value)

        indices[group] = first_above - 2

    return indices
221
222
223 View Code Duplication
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A common acceptance rate is searched over [0, 1]; for each
    candidate rate the per-attribute ROC indices are looked up and the
    proportion-weighted cost is evaluated.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value
             and the acceptance rate at which it is attained.
             The cost value is the search objective, i.e. the negated
             proportion-weighted sum of the group costs.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        # todo: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]

        # Negated because the ternary search minimizes its objective.
        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # The third element must be the cost (as documented and as
    # consumed by plot_costs), not the acceptance rate itself.
    cost = total_cost_function(acceptance_rate_min_cost)

    # The thresholds of the first curve are used for all groups.
    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
281
282
283
def get_fnr_indices(roc_curves, fnr_value):
    """Find, per attribute, the ROC index for a given FNR value.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param fnr_value: Target value; ``1 - tpr`` of each curve is
                      compared against ``1 - fnr_value``.
    :type fnr_value: float
    :return: Index into the ROC arrays by attribute.
    :rtype: dict

    """
    indices = {}

    for group, roc in roc_curves.items():
        # asarray so that plain sequences of TPRs work as well.
        fnrs = 1 - np.asarray(roc[1])

        # NOTE(review): when the very first entry already exceeds the
        # target this yields -1, which wraps to the last ROC point
        # when used as an index - confirm this is intended.
        indices[group] = _first_index_above(fnrs, 1 - fnr_value) - 1

    return indices
293
294
295 View Code Duplication
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A common FNR value is searched over [0, 1]; for each candidate the
    per-attribute ROC indices are looked up and the
    proportion-weighted cost is evaluated.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value
             and the FNR value at which it is attained.
             The cost value is the search objective, i.e. the negated
             proportion-weighted sum of the group costs.
    :rtype: tuple

    """

    def total_cost_function(fnr_value):
        # todo: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            group_cost = _cost_function(roc[0][index], roc[1][index],
                                        base_rates[group],
                                        cost_matrix)
            total_cost += group_cost * proportions[group]

        # Negated because the ternary search minimizes its objective.
        return -total_cost

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    # The thresholds of the first curve are used for all groups.
    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
354
355
356
def _find_feasible_roc(roc_curves):
    """Collect the ROC points that lie inside every group's ROC region.

    Each curve's (fpr, tpr) points are triangulated; a point of any
    curve is kept when every triangulation contains it. A point shared
    by several curves is therefore returned once per curve.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute; each value unpacks into
                       (fprs, tprs, thresholds).
    :type roc_curves: dict
    :return: Array of the feasible (fpr, tpr) points.
    :rtype: :class:`numpy.ndarray`

    """
    polygons = [Delaunay(list(zip(fprs, tprs)))
                for fprs, tprs, _ in roc_curves.values()]

    feasible_points = []

    for poly in polygons:
        for point in poly.points:
            # find_simplex returns -1 for points outside the
            # triangulation.
            if all(other.find_simplex(point) != -1 for other in polygons):
                feasible_points.append(point)

    return np.array(feasible_points)
369
370
371
def find_separation_thresholds(roc_curves, base_rate, cost_matrix):
    """Compute thresholds that achieve separation and minimize cost.

    Also known as **equalized odds**.

    The best point is picked among the ROC points that are feasible
    for every attribute value. No per-attribute thresholds are
    computed, so the first returned element is an empty dict and the
    chosen (FPR, TPR) is stored under the key ``''``.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
             The cost value is negated, matching the sign convention
             of the other ``find_*`` functions of this module.
    :rtype: tuple

    """

    feasible_points = _find_feasible_roc(roc_curves)

    # Maximizing the raw value selects the same optimum as the other
    # criteria, which minimize the negated value.
    best_value, (best_fpr, best_tpr) = max(
        (_cost_function(fpr, tpr, base_rate, cost_matrix), (fpr, tpr))
        for fpr, tpr in feasible_points)

    # Report with the same sign as the sibling functions, so that the
    # values are comparable (e.g. side by side in plot_costs).
    cost = -best_value

    return {}, {'': (best_fpr, best_tpr)}, cost
396
397
398
def find_thresholds(roc_curves, proportions, base_rate,
                    base_rates, cost_matrix,
                    with_single=True, with_min_cost=True,
                    with_independence=True, with_fnr=True,
                    with_separation=True):
    """Compute thresholds that achieve various criteria and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
             Keys are ``'single'``, ``'min_cost'``,
             ``'independence'``, ``'fnr'`` and ``'separation'``,
             present only for the enabled criteria and inserted in
             that order.
    :rtype: dict

    """

    thresholds = {}

    if with_single:
        thresholds['single'] = find_single_threshold(
            roc_curves, base_rates, proportions, cost_matrix)

    if with_min_cost:
        thresholds['min_cost'] = find_min_cost_thresholds(
            roc_curves, base_rates, cost_matrix)

    if with_independence:
        thresholds['independence'] = find_independence_thresholds(
            roc_curves, base_rates, proportions, cost_matrix)

    if with_fnr:
        thresholds['fnr'] = find_fnr_thresholds(
            roc_curves, base_rates, proportions, cost_matrix)

    if with_separation:
        thresholds['separation'] = find_separation_thresholds(
            roc_curves, base_rate, cost_matrix)

    return thresholds
466
467
468
def plot_roc_curves_thresholds(roc_curves, thresholds_data,
                               aucs=None,
                               title='ROC Curves by Attribute',
                               ax=None, figsize=None,
                               title_fontsize='large',
                               text_fontsize='medium'):
    """Generate the ROC curves by attribute with thresholds.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param aucs: Area Under the ROC (AUC) by attribute.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    ax = plot_roc_curves(roc_curves, aucs,
                         title, ax, figsize, title_fontsize, text_fontsize)

    # One marker per criterion; zip stops at the shorter of the two.
    markers = ['o', '^', 'x', '+', 'p']

    for (name, data), marker in zip(thresholds_data.items(), markers):
        label = name.replace('_', ' ').title()

        # data[1] maps attribute -> (fpr, tpr); transpose the pairs
        # into the x and y sequences expected by scatter.
        ax.scatter(*zip(*data[1].values()),
                   marker=marker, color='k', label=label,
                   zorder=float('inf'))

    # Build the legend on the axes that was drawn on, which need not
    # be pyplot's current axes when the caller supplied ``ax``.
    ax.legend()

    return ax
518
519
520
def plot_fpt_tpr(roc_curves,
                 title='FPR-TPR Curves by Attribute',
                 ax=None, figsize=None,
                 title_fontsize='large', text_fontsize='medium'):
    """Generate FPR and TPR curves by thresholds and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    # The thresholds of the first curve are used as x for all groups.
    thresholds = _extract_threshold(roc_curves)

    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    for index, (group, roc) in enumerate(roc_curves.items()):
        # Wrap around the palette so that no group is dropped when
        # there are more groups than colors.
        color = colors[index % len(colors)]

        # Draw on ``ax`` itself: pyplot's current axes may be a
        # different one when the caller supplied ``ax``.
        ax.plot(thresholds, roc[0], '-',
                label='{} - FPR'.format(group), color=color)
        ax.plot(thresholds, roc[1], '--',
                label='{} - TPR'.format(group), color=color)

    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Probability', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(fontsize=text_fontsize)

    return ax
571
572
573
def plot_costs(thresholds_data,
               title='Cost by Threshold',
               ax=None, figsize=None,
               title_fontsize='large', text_fontsize='medium'):
    """Plot cost by threshold definition and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    # The cost is the third element of each criterion's tuple; any
    # further elements are ignored.
    costs = {criterion.replace('_', ' ').title(): cost
             for criterion, (_, _, cost, *_) in thresholds_data.items()}

    sorted_costs = pd.Series(costs).sort_values(ascending=False)
    sorted_costs.plot(kind='barh', ax=ax)

    ax.set_xlabel('Cost', fontsize=text_fontsize)
    ax.set_ylabel('Threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
619