"""
Post-processing fairness intervention by choosing thresholds.

There are multiple definitions for choosing the thresholds:

1. A single threshold for all the sensitive attribute values
   that minimizes cost.
2. Thresholds, one for each sensitive attribute value,
   that minimize cost.
3. Thresholds, one for each sensitive attribute value,
   that achieve independence and minimize cost.
4. Thresholds, one for each sensitive attribute value,
   that achieve equal FNR (equal opportunity) and minimize cost.
5. Thresholds, one for each sensitive attribute value,
   that achieve separation (equalized odds) and minimize cost.


The code is based on the `fairmlbook repository <https://github.com/fairmlbook/fairmlbook.github.io>`_.


References:
    - Hardt, M., Price, E., & Srebro, N. (2016).
      `Equality of opportunity in supervised learning
      <https://arxiv.org/abs/1610.02413>`_.
      In Advances in Neural Information Processing Systems
      (pp. 3315-3323).
    - `Attacking discrimination with
      smarter machine learning by Google
      <https://research.google.com/bigpicture/attacking-discrimination-in-ml/>`_.

"""

# pylint: disable=no-name-in-module

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.spatial import Delaunay

from ethically.fairness.metrics.visualization import plot_roc_curves
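

# Illustrative sketch (hypothetical helper, not part of this module's API
# and not used by the functions below): each criterion in the module
# docstring constrains the per-group operating points (fpr, tpr).
# Independence asks for equal positive prediction rates,
# fpr * (1 - base_rate) + tpr * base_rate; equal opportunity asks for
# equal FNRs, i.e. equal TPRs; separation asks for equal FPRs and TPRs.
# The labels follow the keys used by find_thresholds; the float
# tolerance `tol` is an assumption made for this sketch.
def _sketch_satisfied_criteria(fpr_tpr, base_rates, tol=1e-9):
    """Return the criteria met by a dict of per-group (fpr, tpr) pairs."""
    positive_rates = [fpr * (1 - base_rates[group]) + tpr * base_rates[group]
                      for group, (fpr, tpr) in fpr_tpr.items()]
    fprs = [fpr for fpr, _ in fpr_tpr.values()]
    tprs = [tpr for _, tpr in fpr_tpr.values()]

    def all_close(values):
        return max(values) - min(values) <= tol

    criteria = set()
    if all_close(positive_rates):
        criteria.add('independence')
    if all_close(tprs):
        criteria.add('fnr')
    if all_close(fprs) and all_close(tprs):
        criteria.add('separation')
    return criteria

# Usage with made-up rates: two groups sharing the same (fpr, tpr) but with
# different base rates meet separation and equal FNRs, not independence:
#   _sketch_satisfied_criteria({'a': (0.2, 0.8), 'b': (0.2, 0.8)},
#                              {'a': 0.3, 'b': 0.5})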


def _ternary_search_float(f, left, right, tol):
    """Ternary search: minimize f(x) over [left, right], to within +/-tol in x.

    Assumes f is quasiconvex on [left, right].

    """
    while right - left > tol:
        left_third = (2 * left + right) / 3
        right_third = (left + 2 * right) / 3
        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third
    return (right + left) / 2


def _ternary_search_domain(f, domain):
    """Ternary search: minimize f(x) over a domain (sequence).

    Assumes f is quasiconvex over the domain and the domain
    is sorted in ascending order.

    """
    left = 0
    right = len(domain) - 1

    # Shrink the index range by about a third per step
    # while it still holds at least four elements.
    while right - left > 2:
        left_third = (2 * left + right) // 3
        right_third = (left + 2 * right) // 3

        if f(domain[left_third]) < f(domain[right_third]):
            right = right_third - 1
        else:
            left = left_third + 1

    # Compare the (at most three) remaining candidates directly,
    # so that a two-element range is not resolved without comparing both.
    return min(domain[left:right + 1], key=f)


def _cost_function(fpr, tpr, base_rate, cost_matrix):
    """Compute the cost of a given (fpr, tpr) operating point.

    The cost matrix is laid out as [[tn, fp], [fn, tp]]; each entry
    is weighted by the fraction of the population that falls into
    the corresponding cell of the confusion matrix.
    """

    fp = fpr * (1 - base_rate)
    tn = (1 - base_rate) - fp
    tp = tpr * base_rate
    fn = base_rate - tp

    conf_matrix = np.array([tn, fp, fn, tp])

    return (conf_matrix * np.array(cost_matrix).ravel()).sum()
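

# Illustrative sketch (hypothetical helper, pure Python): the four
# confusion-matrix shares that _cost_function derives from (fpr, tpr)
# and the base rate. The shares always sum to 1, so the cost is a
# weighted average of the entries of the [[tn, fp], [fn, tp]] matrix.
def _sketch_confusion_shares(fpr, tpr, base_rate):
    """Return (tn, fp, fn, tp) as fractions of the whole population."""
    negatives = 1 - base_rate
    fp = fpr * negatives
    tn = negatives - fp
    tp = tpr * base_rate
    fn = base_rate - tp
    return tn, fp, fn, tp

# Usage with a hypothetical cost matrix [[0, -1], [-2, 0]]:
#   shares = _sketch_confusion_shares(fpr=0.1, tpr=0.8, base_rate=0.3)
#   cost = sum(s * c for s, c in zip(shares, [0, -1, -2, 0]))
#   # shares are about (0.63, 0.07, 0.06, 0.24); cost is about -0.19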


def _extract_threshold(roc_curves):
    """Return the threshold grid of the first ROC curve.

    Callers index the FPR and TPR arrays of every group with positions
    in this grid, so all ROC curves must share the same thresholds.
    """
    return next(iter(roc_curves.values()))[2]


def _first_index_above(array, value):
    """Find the smallest index i for which array[i] > value.

    If no such value exists, return len(array).
    """
    array = np.array(array)
    v = np.concatenate([array > value, np.ones_like(array[-1:])])
    return np.argmax(v, axis=0)
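

# Illustrative sketch (hypothetical helper): a plain-Python counterpart of
# _first_index_above. It scans the sequence once and falls back to
# len(array) when no element exceeds `value`, which plays the role of the
# sentinel `1` that _first_index_above appends before calling np.argmax.
# Unlike the NumPy version, it also accepts an empty sequence.
def _sketch_first_index_above(array, value):
    return next((i for i, x in enumerate(array) if x > value), len(array))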


def _calc_acceptance_rate(fpr, tpr, base_rate):
    # One minus the positive prediction rate,
    # fpr * (1 - base_rate) + tpr * base_rate.
    return 1 - (fpr * (1 - base_rate)
                + tpr * base_rate)


def find_single_threshold(roc_curves, base_rates, proportions,
                          cost_matrix):
    """Compute a single threshold that minimizes cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Threshold, FPR and TPR by attribute and cost value.
    :rtype: tuple

    """

    def total_cost_function(index):
        total_cost = 0

        for group, roc in roc_curves.items():
            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group], cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        # Negated because _ternary_search_domain minimizes its objective.
        return -total_cost

    thresholds = _extract_threshold(roc_curves)

    cutoff_index = _ternary_search_domain(total_cost_function,
                                          range(len(thresholds)))

    fpr_tpr = {group: (roc[0][cutoff_index], roc[1][cutoff_index])
               for group, roc in roc_curves.items()}

    cost = total_cost_function(cutoff_index)

    return thresholds[cutoff_index], fpr_tpr, cost
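

# Illustrative sketch (hypothetical helper, pure Python): an exhaustive
# version of the search in find_single_threshold. The ternary search is
# only reliable when the negated total is unimodal (quasiconvex) over the
# threshold index; scanning every index makes no such assumption, so it
# can serve as a reference when checking the faster search on given data.
# `cost_fn(fpr, tpr, base_rate)` stands in for _cost_function with a fixed
# cost matrix; that signature is an assumption of this sketch.
def _sketch_exhaustive_single_threshold(roc_curves, base_rates,
                                        proportions, cost_fn):
    n_thresholds = len(next(iter(roc_curves.values()))[2])

    def total(index):
        return sum(proportions[group]
                   * cost_fn(roc[0][index], roc[1][index], base_rates[group])
                   for group, roc in roc_curves.items())

    # Like find_single_threshold, keep the index with the largest
    # proportion-weighted sum (it minimizes the negated total).
    best_index = max(range(n_thresholds), key=total)
    return best_index, total(best_index)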


def find_min_cost_thresholds(roc_curves, base_rates, cost_matrix):
    """Compute thresholds by attribute values that minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value
             (summed over groups without proportion weights).
    :rtype: tuple

    """
    # pylint: disable=cell-var-from-loop

    cutoffs = {}
    fpr_tpr = {}

    cost = 0
    thresholds = _extract_threshold(roc_curves)

    for group, roc in roc_curves.items():
        def group_cost_function(index):
            fpr = roc[0][index]
            tpr = roc[1][index]
            return -_cost_function(fpr, tpr,
                                   base_rates[group], cost_matrix)

        threshold_index = _ternary_search_domain(group_cost_function,
                                                 range(len(thresholds)))

        cutoffs[group] = thresholds[threshold_index]

        fpr_tpr[group] = (roc[0][threshold_index],
                          roc[1][threshold_index])

        cost += group_cost_function(threshold_index)

    return cutoffs, fpr_tpr, cost


def get_acceptance_rate_indices(roc_curves, base_rates,
                                acceptance_rate_value):
    indices = {}
    for group, roc in roc_curves.items():
        # The acceptance rates do not depend on acceptance_rate_value,
        # so they could be computed once, outside this function.
        acceptance_rates = _calc_acceptance_rate(fpr=roc[0],
                                                 tpr=roc[1],
                                                 base_rate=base_rates[group])

        index = _first_index_above(acceptance_rates,
                                   (1 - acceptance_rate_value)) - 2

        indices[group] = index

    return indices


def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value,
             and the common acceptance rate found by the search.
    :rtype: tuple

    """

    def total_cost_function(acceptance_rate_value):
        # todo: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    cost = total_cost_function(acceptance_rate_min_cost)

    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
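

# Illustrative sketch (hypothetical helper, pure Python): independence
# (demographic parity) asks every group to receive positive predictions at
# the same rate. For a target rate, this picks, per group, the grid index
# whose positive prediction rate fpr * (1 - base_rate) + tpr * base_rate
# is closest to the target. The nearest-match rule is an assumption made
# for illustration; get_acceptance_rate_indices instead uses
# _first_index_above with a fixed offset.
def _sketch_parity_indices(roc_curves, base_rates, target_rate):
    indices = {}
    for group, (fprs, tprs, _) in roc_curves.items():
        base_rate = base_rates[group]
        rates = [fpr * (1 - base_rate) + tpr * base_rate
                 for fpr, tpr in zip(fprs, tprs)]
        indices[group] = min(range(len(rates)),
                             key=lambda i: abs(rates[i] - target_rate))
    return indices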


def get_fnr_indices(roc_curves, fnr_value):
    indices = {}
    for group, roc in roc_curves.items():
        tprs = roc[1]
        index = _first_index_above(1 - tprs,
                                   (1 - fnr_value)) - 1

        indices[group] = index

    return indices


def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value,
             and the common FNR value found by the search.
    :rtype: tuple

    """

    def total_cost_function(fnr_value):
        # todo: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        return -total_cost

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost


def _find_feasible_roc(roc_curves):
    """Find the ROC points that lie inside every group's ROC region.

    A point is kept if it falls inside the convex hull (via a Delaunay
    triangulation) of the ROC points of every group.
    """
    polygons = [Delaunay(list(zip(fprs, tprs)))
                for fprs, tprs, _ in roc_curves.values()]

    feasible_points = []

    for poly in polygons:
        for p in poly.points:

            if all(poly2.find_simplex(p) != -1 for poly2 in polygons):
                feasible_points.append(p)

    return np.array(feasible_points)
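

# Illustrative sketch (hypothetical helper, pure Python): _find_feasible_roc
# uses Delaunay(...).find_simplex(p) != -1 to test whether a point lies in
# the convex hull of a group's ROC points. For a convex polygon whose
# vertices are listed counter-clockwise, the same membership test reduces
# to checking that the point is never strictly to the right of an edge.
# Computing that counter-clockwise hull ordering is assumed to happen
# elsewhere; this is a swapped-in technique, not what the module uses.
def _sketch_in_convex_polygon(point, vertices):
    px, py = point
    n = len(vertices)
    for i in range(n):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % n]
        # Cross product of the edge vector and the vector to the point;
        # a negative value means the point is right of the edge (outside).
        if (x2 - x1) * (py - y1) - (y2 - y1) * (px - x1) < 0:
            return False
    return True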


def find_separation_thresholds(roc_curves, base_rate, cost_matrix):
    """Compute thresholds that achieve separation and minimize cost.

    Also known as **equalized odds**.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: An empty thresholds dict, the (FPR, TPR) point shared
             by all groups under the key ``''``, and the cost value.
    :rtype: tuple

    """

    feasible_points = _find_feasible_roc(roc_curves)

    cost, (best_fpr, best_tpr) = max((_cost_function(fpr, tpr, base_rate,
                                                     cost_matrix),
                                      (fpr, tpr))
                                     for fpr, tpr in feasible_points)

    return {}, {'': (best_fpr, best_tpr)}, cost


def find_thresholds(roc_curves, proportions, base_rate,
                    base_rates, cost_matrix,
                    with_single=True, with_min_cost=True,
                    with_independence=True, with_fnr=True,
                    with_separation=True):
    """Compute thresholds that achieve various criteria and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
    :rtype: dict

    """

    thresholds = {}

    if with_single:
        thresholds['single'] = find_single_threshold(roc_curves,
                                                     base_rates,
                                                     proportions,
                                                     cost_matrix)

    if with_min_cost:
        thresholds['min_cost'] = find_min_cost_thresholds(roc_curves,
                                                          base_rates,
                                                          cost_matrix)

    if with_independence:
        thresholds['independence'] = find_independence_thresholds(roc_curves,
                                                                  base_rates,
                                                                  proportions,
                                                                  cost_matrix)

    if with_fnr:
        thresholds['fnr'] = find_fnr_thresholds(roc_curves,
                                                base_rates,
                                                proportions,
                                                cost_matrix)

    if with_separation:
        thresholds['separation'] = find_separation_thresholds(roc_curves,
                                                              base_rate,
                                                              cost_matrix)

    return thresholds


def plot_roc_curves_thresholds(roc_curves, thresholds_data,
                               aucs=None,
                               title='ROC Curves by Attribute',
                               ax=None, figsize=None,
                               title_fontsize='large',
                               text_fontsize='medium'):
    """Generate the ROC curves by attribute with thresholds.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param aucs: Area Under the ROC (AUC) by attribute.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    ax = plot_roc_curves(roc_curves, aucs,
                         title, ax, figsize, title_fontsize, text_fontsize)

    MARKERS = ['o', '^', 'x', '+', 'p']

    for (name, data), marker in zip(thresholds_data.items(), MARKERS):
        label = name.replace('_', ' ').title()
        ax.scatter(*zip(*data[1].values()),
                   marker=marker, color='k', label=label,
                   zorder=float('inf'))

    # Draw the legend on the given axes, not on whichever axes is current.
    ax.legend()

    return ax


def plot_fpt_tpr(roc_curves,
                 title='FPR-TPR Curves by Attribute',
                 ax=None, figsize=None,
                 title_fontsize='large', text_fontsize='medium'):
    """Generate FPR and TPR curves by thresholds and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    thresholds = _extract_threshold(roc_curves)

    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']

    # Plot on the given axes, not on whichever axes is current.
    for (group, roc), color in zip(roc_curves.items(), colors):
        ax.plot(thresholds, roc[0], '-',
                label='{} - FPR'.format(group), color=color)
        ax.plot(thresholds, roc[1], '--',
                label='{} - TPR'.format(group), color=color)

    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Probability', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(fontsize=text_fontsize)

    return ax


def plot_costs(thresholds_data,
               title='Cost by Threshold',
               ax=None, figsize=None,
               title_fontsize='large', text_fontsize='medium'):
    """Plot cost by threshold definition and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~ethically.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    costs = {group.replace('_', ' ').title(): cost
             for group, (_, _, cost, *_) in thresholds_data.items()}

    (pd.Series(costs)
     .sort_values(ascending=False)
     .plot(kind='barh', ax=ax))

    ax.set_xlabel('Cost', fontsize=text_fontsize)
    ax.set_ylabel('Threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax