Passed
Pull Request — master (#14)
by Shlomi
04:42 queued 02:29
created

ethically.fairness.interventions.threshold   A

Complexity

Total Complexity 41

Size/Duplication

Total Lines 619
Duplicated Lines 18.9 %

Importance

Changes 0
Metric  Value
eloc    258
dl      117
loc     619
rs      9.1199
c       0
b       0
f       0
wmc     41

18 Functions

Rating  Name                            Duplication  Size  Complexity
A       _ternary_search_float()         0            14    3
A       plot_costs()                    0            46    2
A       find_fnr_thresholds()           59           59    2
A       _ternary_search_domain()        0            25    4
A       find_single_threshold()         0            44    2
A       get_fnr_indices()               0            10    2
A       find_min_cost_thresholds()      0            40    2
A       plot_roc_curves_thresholds()    0            50    2
A       _first_index_above()            0            8     1
A       get_acceptance_rate_indices()   0            15    2
A       find_separation_thresholds()    0            25    1
A       find_independence_thresholds()  58           58    2
A       plot_fpt_tpr()                  0            51    3
A       _calc_acceptance_rate()         0            3     1
A       _find_feasible_roc()            0            13    4
A       _extract_threshold()            0            2     1
A       _cost_function()                0            14    1
B       find_thresholds()               0            68    6


Duplicated Code

Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

Complexity

Tip: Before tackling complexity, eliminate any duplication first. This alone can often reduce the size of classes significantly.

Complex classes like ethically.fairness.interventions.threshold often do a lot of different things. To break such a class down, identify a cohesive component within it. A common approach is to look for fields or methods that share a prefix or suffix.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

"""
Post-processing fairness intervention by choosing thresholds.

The thresholds can be chosen according to one of several criteria:

1. A single threshold for all the sensitive attribute values
   that minimizes cost.
2. A separate threshold for each sensitive attribute value
   that minimizes cost.
3. A separate threshold for each sensitive attribute value
   that achieves independence and minimizes cost.
4. A separate threshold for each sensitive attribute value
   that achieves equal FNR (equal opportunity) and minimizes cost.
5. A separate threshold for each sensitive attribute value
   that achieves separation (equalized odds) and minimizes cost.

The code is based on the `fairmlbook repository <https://github.com/fairmlbook/fairmlbook.github.io>`_.

References:
    - Hardt, M., Price, E., & Srebro, N. (2016).
      Equality of opportunity in supervised learning.
      In Advances in neural information processing systems
      (pp. 3315-3323).
    - `Attacking discrimination with
      smarter machine learning by Google
      <https://research.google.com/bigpicture/attacking-discrimination-in-ml/>`_.

"""

# pylint: disable=no-name-in-module

import matplotlib.pylab as plt
import numpy as np
import pandas as pd
from scipy.spatial import Delaunay

from ethically.fairness.metrics.visualization import plot_roc_curves


def _ternary_search_float(f, left, right, tol):
    """Ternary search: minimize f(x) over [left, right], to within +/-tol in x.

    Assumes f is quasiconvex (unimodal) on [left, right].

    """
    while right - left > tol:
        left_third = (2 * left + right) / 3
        right_third = (left + 2 * right) / 3
        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third
    return (right + left) / 2
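Each pass of the loop above keeps the two thirds of the interval that must contain the minimum of a unimodal function. The standalone sketch below applies the same loop to a function whose minimizer is known. The function name `ternary_search_float` and the test parabola are illustrative only and are not part of the module.

```python
def ternary_search_float(f, left, right, tol):
    """Minimize a unimodal f over [left, right] to within +/-tol in x."""
    while right - left > tol:
        left_third = (2 * left + right) / 3
        right_third = (left + 2 * right) / 3
        # Keep the two thirds of the interval that must hold the minimum.
        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third
    return (left + right) / 2


# A convex parabola with its minimum at x = 0.3.
x_min = ternary_search_float(lambda x: (x - 0.3) ** 2, 0.0, 1.0, 1e-3)
print(x_min)
```

Each iteration shrinks the interval to two thirds of its width. With `tol=1e-3` on [0, 1], as used by the independence and FNR searches in this module, the loop therefore runs 18 times.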


def _ternary_search_domain(f, domain):
    """Ternary search: minimize f(x) over a domain (sequence).

    Assumes f is unimodal over the domain (quasiconvex, with no plateaus
    away from the minimum) and the domain is sorted in ascending order.

    """
    left = 0
    right = len(domain) - 1

    # Narrow [left, right] while it holds at least four candidates;
    # then step >= 1, so left_third < right_third and every pass
    # discards at least one index.
    while right - left > 2:
        step = (right - left) // 3
        left_third = left + step
        right_third = right - step

        if f(domain[left_third]) < f(domain[right_third]):
            # The minimum cannot lie in [right_third, right].
            right = right_third - 1
        else:
            # f(left_third) >= f(right_third): discard [left, left_third].
            left = left_third + 1

    # At most three candidates remain; compare them exhaustively.
    return min(domain[left:right + 1], key=f)
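One way to gain confidence in an index-based ternary search is to compare it with a brute-force minimum on small unimodal sequences. Off-by-one mistakes tend to hide in very short ranges, so the check covers those too. The sketch is standalone: `ternary_search_index` mirrors the loop above, and the V-shaped test sequences are made up for illustration.

```python
import itertools


def ternary_search_index(f, domain):
    """Minimize a unimodal f over a sorted sequence; return the best element."""
    left, right = 0, len(domain) - 1
    while right - left > 2:
        step = (right - left) // 3
        left_third, right_third = left + step, right - step
        if f(domain[left_third]) < f(domain[right_third]):
            right = right_third - 1
        else:
            left = left_third + 1
    # Few candidates remain: compare them directly.
    return min(domain[left:right + 1], key=f)


# Every V-shaped sequence of length 1..12 with a unique minimum at `center`.
mismatches = []
for length, center in itertools.product(range(1, 13), range(12)):
    if center >= length:
        continue
    values = [abs(i - center) for i in range(length)]
    found = ternary_search_index(lambda i: values[i], range(length))
    if found != center:
        mismatches.append((length, center, found))

print(mismatches)  # []
```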


def _cost_function(fpr, tpr, base_rate, cost_matrix):
    """Compute the cost of a given (FPR, TPR) operating point.

    The cost matrix is laid out as [[tn, fp], [fn, tp]], and each entry
    is weighted by the corresponding confusion-matrix proportion.
    """

    fp = fpr * (1 - base_rate)
    tn = (1 - base_rate) - fp
    tp = tpr * base_rate
    fn = base_rate - tp

    conf_matrix = np.array([tn, fp, fn, tp])

    return (conf_matrix * np.array(cost_matrix).ravel()).sum()
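To make the weighting concrete, the standalone sketch below repeats the same arithmetic for one operating point. The numbers (FPR 0.2, TPR 0.7, base rate 0.3, and the matrix `[[0, -1], [-2, 3]]`) are arbitrary illustrative values. The search routines in this module minimize the negated value of this function, so in effect they maximize what it returns.

```python
import numpy as np


def weighted_confusion_value(fpr, tpr, base_rate, cost_matrix):
    """Weight the confusion-matrix proportions by a [[tn, fp], [fn, tp]] matrix."""
    fp = fpr * (1 - base_rate)      # negatives predicted positive
    tn = (1 - base_rate) - fp       # negatives predicted negative
    tp = tpr * base_rate            # positives predicted positive
    fn = base_rate - tp             # positives predicted negative
    proportions = np.array([tn, fp, fn, tp])
    return float((proportions * np.array(cost_matrix).ravel()).sum())


# tn = 0.56, fp = 0.14, fn = 0.09, tp = 0.21 (they sum to 1)
value = weighted_confusion_value(0.2, 0.7, 0.3, [[0, -1], [-2, 3]])
print(round(value, 2))  # 0.31
```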
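

def _extract_threshold(roc_curves):
    """Return the thresholds of the first group's ROC curve.

    Assumes that all the groups' ROC curves share the same thresholds.
    """
    return next(iter(roc_curves.values()))[2]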
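`_extract_threshold` reads the thresholds from the first group only, and the search functions index every group's FPR and TPR arrays with the same index. The ROC curves passed to this module therefore need to be computed on a common threshold grid. `sklearn.metrics.roc_curve` returns its own (decreasing) threshold array for each call, so it does not produce such a shared grid by itself. The sketch below shows one hypothetical way to build a `roc_curves` dict (group mapped to `(fprs, tprs, thresholds)`) with NumPy. The helper name `roc_curves_on_shared_thresholds`, the synthetic data, and the ascending threshold order are assumptions for illustration, not part of the library. Ascending thresholds make the FNR and `1 - positive rate` arrays non-decreasing along the index, which is what `_first_index_above` appears to expect.

```python
import numpy as np


def roc_curves_on_shared_thresholds(y_true, y_score, groups, thresholds):
    """Hypothetical helper: per-group (fprs, tprs, thresholds) on one grid.

    A sample is predicted positive when its score is >= the threshold.
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_score = np.asarray(y_score, dtype=float)
    groups = np.asarray(groups)
    thresholds = np.asarray(thresholds, dtype=float)

    roc_curves = {}
    for group in np.unique(groups):
        mask = groups == group
        labels, scores = y_true[mask], y_score[mask]
        # Rows: thresholds, columns: samples of this group.
        predicted = scores[np.newaxis, :] >= thresholds[:, np.newaxis]
        tprs = predicted[:, labels].mean(axis=1)
        fprs = predicted[:, ~labels].mean(axis=1)
        roc_curves[group] = (fprs, tprs, thresholds)
    return roc_curves


rng = np.random.default_rng(0)
groups = rng.choice(['a', 'b'], size=1000)
y_true = rng.random(1000) < 0.4
y_score = np.clip(0.3 * y_true + rng.normal(0.35, 0.2, 1000), 0, 1)
grid = np.linspace(0, 1, 101)  # ascending, shared by all groups

curves = roc_curves_on_shared_thresholds(y_true, y_score, groups, grid)
for group, (fprs, tprs, _) in curves.items():
    print(group, round(float(fprs[50]), 2), round(float(tprs[50]), 2))
```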


def _first_index_above(array, value):
    """Find the smallest index i for which array[i] > value.

    If no such value exists, return len(array).
    """
    array = np.array(array)
    v = np.concatenate([array > value, np.ones_like(array[-1:])])
    return np.argmax(v, axis=0)
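The sentinel trick in `_first_index_above` works because `np.argmax` returns the first position of the maximum. Appending a single true value guarantees that a maximum exists even when no element exceeds `value`, and in that case the result is `len(array)`. A standalone sketch with a boolean sentinel (the name `first_index_above` is illustrative):

```python
import numpy as np


def first_index_above(array, value):
    """Smallest i with array[i] > value, or len(array) if there is none."""
    above = np.asarray(array) > value
    # The trailing True acts as a sentinel at position len(array).
    return int(np.argmax(np.append(above, True)))


print(first_index_above([0.1, 0.5, 0.9], 0.4))  # 1
print(first_index_above([0.1, 0.5, 0.9], 1.0))  # 3
```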
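

def _calc_acceptance_rate(fpr, tpr, base_rate):
    """Return 1 - (FPR * (1 - base_rate) + TPR * base_rate).

    FPR * (1 - base_rate) + TPR * base_rate is the fraction of positive
    predictions, so this returns its complement.
    """
    return 1 - ((fpr * (1 - base_rate)
                 + tpr * base_rate))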
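The expression in `_calc_acceptance_rate` rests on an identity: the fraction of positive predictions equals `FPR * (1 - base_rate) + TPR * base_rate`. False positives contribute FPR times the share of negatives, and true positives contribute TPR times the share of positives. The standalone check below confirms the identity on synthetic labels and predictions (the data and random seed are arbitrary).

```python
import numpy as np

rng = np.random.default_rng(42)
y_true = rng.random(10_000) < 0.3     # synthetic labels
y_pred = rng.random(10_000) < 0.5     # synthetic binary predictions

base_rate = y_true.mean()
fpr = y_pred[~y_true].mean()          # P(pred = 1 | y = 0)
tpr = y_pred[y_true].mean()           # P(pred = 1 | y = 1)

positive_rate = fpr * (1 - base_rate) + tpr * base_rate
print(np.isclose(positive_rate, y_pred.mean()))  # True
```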


def find_single_threshold(roc_curves, base_rates, proportions,
                          cost_matrix):
    """Compute a single threshold that minimizes cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Threshold, FPR and TPR by attribute and cost value.
    :rtype: tuple

    """

    def total_cost_function(index):
        total_cost = 0

        for group, roc in roc_curves.items():
            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group], cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        return -total_cost

    thresholds = _extract_threshold(roc_curves)

    cutoff_index = _ternary_search_domain(total_cost_function,
                                          range(len(thresholds)))

    fpr_tpr = {group: (roc[0][cutoff_index], roc[1][cutoff_index])
               for group, roc in roc_curves.items()}

    cost = total_cost_function(cutoff_index)

    return thresholds[cutoff_index], fpr_tpr, cost


def find_min_cost_thresholds(roc_curves, base_rates, cost_matrix):
    """Compute thresholds, one per attribute value, that minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value.
    :rtype: tuple

    """
    # pylint: disable=cell-var-from-loop

    cutoffs = {}
    fpr_tpr = {}

    cost = 0
    thresholds = _extract_threshold(roc_curves)

    for group, roc in roc_curves.items():
        def group_cost_function(index):
            fpr = roc[0][index]
            tpr = roc[1][index]
            return -_cost_function(fpr, tpr,
                                   base_rates[group], cost_matrix)

        threshold_index = _ternary_search_domain(group_cost_function,
                                                 range(len(thresholds)))

        cutoffs[group] = thresholds[threshold_index]

        fpr_tpr[group] = (roc[0][threshold_index],
                          roc[1][threshold_index])

        cost += group_cost_function(threshold_index)

    return cutoffs, fpr_tpr, cost
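`find_min_cost_thresholds` optimizes each group separately, while `find_single_threshold` forces one shared index. As a result, the per-group optimum is never worse than the best common index, because a sum of maxima is at least the maximum of the sum. The brute-force sketch below illustrates this on made-up per-group value curves (higher is better, matching the fact that the searches minimize the negated value). It weights both totals by the group proportions for a like-for-like comparison. Note that `find_min_cost_thresholds` itself does not take `proportions`, so the cost it returns is an unweighted sum.

```python
import numpy as np

index = np.arange(50)          # shared threshold indices
proportions = {'a': 0.6, 'b': 0.4}

# Made-up value curves (higher is better) whose best indices differ.
values = {'a': -(index - 15) ** 2 / 100.0,
          'b': -(index - 35) ** 2 / 100.0}

# Single threshold: one index shared by both groups.
weighted_total = sum(proportions[g] * v for g, v in values.items())
single_best = weighted_total.max()

# Per-group thresholds: each group takes its own best index.
per_group_best = sum(proportions[g] * v.max() for g, v in values.items())

print(int(weighted_total.argmax()), round(float(single_best), 2),
      float(per_group_best))  # 23 -0.96 0.0
```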


def get_acceptance_rate_indices(roc_curves, base_rates,
                                acceptance_rate_value):
    indices = {}
    for group, roc in roc_curves.items():
        # Independent of acceptance_rate_value, so it could be
        # computed once per group outside this function.
        acceptance_rates = _calc_acceptance_rate(fpr=roc[0],
                                                 tpr=roc[1],
                                                 base_rate=base_rates[group])

        index = _first_index_above(acceptance_rates,
                                   (1 - acceptance_rate_value)) - 2

        indices[group] = index

    return indices


def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value,
             and the acceptance rate value.
    :rtype: tuple

    """

    cutoffs = {}

    def total_cost_function(acceptance_rate_value):
        # todo: move demo here + multiple cost
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        return -total_cost

    acceptance_rate_min_cost = _ternary_search_float(total_cost_function,
                                                     0, 1, 1e-3)
    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)

    # The third element is the cost, as in find_fnr_thresholds
    # and as plot_costs expects.
    cost = total_cost_function(acceptance_rate_min_cost)

    thresholds = _extract_threshold(roc_curves)

    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
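The independence criterion asks that every group receive positive predictions at the same rate. To illustrate that constraint, the sketch below takes a simpler route than the ROC-based search above: it picks each group's threshold as a quantile of that group's scores. This is a different technique from `find_independence_thresholds`; it only illustrates the constraint, not how the module searches for it. The synthetic scores and the helper name `equal_positive_rate_thresholds` are hypothetical.

```python
import numpy as np


def equal_positive_rate_thresholds(scores_by_group, positive_rate):
    """Hypothetical helper: per-group thresholds giving the same positive rate.

    A sample is predicted positive when its score is >= its group's threshold.
    """
    return {group: np.quantile(scores, 1 - positive_rate)
            for group, scores in scores_by_group.items()}


rng = np.random.default_rng(7)
scores_by_group = {'a': rng.normal(0.6, 0.15, 2000),
                   'b': rng.normal(0.4, 0.15, 3000)}

cutoffs = equal_positive_rate_thresholds(scores_by_group, positive_rate=0.3)
for group, scores in scores_by_group.items():
    rate = (scores >= cutoffs[group]).mean()
    print(group, round(float(cutoffs[group]), 2), round(float(rate), 3))
```

The two groups end up with different thresholds but (approximately) the same 30% positive rate.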


def get_fnr_indices(roc_curves, fnr_value):
    indices = {}
    for group, roc in roc_curves.items():
        tprs = roc[1]
        index = _first_index_above(1 - tprs,
                                   (1 - fnr_value)) - 1

        indices[group] = index

    return indices


def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value,
             and the FNR value.
    :rtype: tuple

    """

    cutoffs = {}

    def total_cost_function(fnr_value):
        # todo: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        total_cost = 0

        for group, roc in roc_curves.items():
            index = indices[group]

            fpr = roc[0][index]
            tpr = roc[1][index]

            group_cost = _cost_function(fpr, tpr,
                                        base_rates[group],
                                        cost_matrix)
            group_cost *= proportions[group]

            total_cost += group_cost

        return -total_cost

    fnr_value_min_cost = _ternary_search_float(total_cost_function,
                                               0, 1, 1e-3)
    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = total_cost_function(fnr_value_min_cost)

    fpr_tpr = {group: (roc[0][threshold_indices[group]],
                       roc[1][threshold_indices[group]])
               for group, roc in roc_curves.items()}

    thresholds = _extract_threshold(roc_curves)
    cutoffs = {group: thresholds[threshold_index]
               for group, threshold_index
               in threshold_indices.items()}

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
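Scrutinizer flags `find_independence_thresholds` and `find_fnr_thresholds` as duplicates (58 and 59 duplicated lines). Their nested `total_cost_function` bodies differ only in how the per-group indices are obtained. One possible Extract Method refactoring is therefore a shared helper that receives those indices. The sketch below is standalone and hypothetical: `weighted_total_cost` is not part of the module, `cost_value` re-implements `_cost_function`, and the two-group ROC data is made up.

```python
import numpy as np


def cost_value(fpr, tpr, base_rate, cost_matrix):
    """Same arithmetic as _cost_function: [[tn, fp], [fn, tp]] weighting."""
    fp = fpr * (1 - base_rate)
    tp = tpr * base_rate
    conf = np.array([(1 - base_rate) - fp, fp, base_rate - tp, tp])
    return (conf * np.array(cost_matrix).ravel()).sum()


def weighted_total_cost(roc_curves, indices, base_rates, proportions,
                        cost_matrix):
    """Hypothetical shared helper: proportion-weighted cost at given indices.

    Returns the negated total, as the nested total_cost_function does,
    so it can be passed directly to a minimizing search.
    """
    total = 0.0
    for group, (fprs, tprs, _) in roc_curves.items():
        index = indices[group]
        total += proportions[group] * cost_value(
            fprs[index], tprs[index], base_rates[group], cost_matrix)
    return -total


# In the module, each nested total_cost_function could then become e.g.
#     return weighted_total_cost(roc_curves,
#                                get_fnr_indices(roc_curves, fnr_value),
#                                base_rates, proportions, cost_matrix)

thresholds = np.linspace(0, 1, 5)
roc_curves = {'a': (np.array([1.0, 0.6, 0.3, 0.1, 0.0]),
                    np.array([1.0, 0.9, 0.7, 0.4, 0.0]), thresholds),
              'b': (np.array([1.0, 0.5, 0.2, 0.05, 0.0]),
                    np.array([1.0, 0.8, 0.6, 0.3, 0.0]), thresholds)}
base_rates = {'a': 0.4, 'b': 0.3}
proportions = {'a': 0.5, 'b': 0.5}
cost_matrix = [[0, -1], [-2, 3]]

print(weighted_total_cost(roc_curves, {'a': 2, 'b': 1},
                          base_rates, proportions, cost_matrix))
```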


def _find_feasible_roc(roc_curves):
    """Return the ROC points that lie inside every group's convex hull.

    Each group's (FPR, TPR) points are triangulated with Delaunay;
    find_simplex returns -1 for points outside a triangulation, i.e.,
    outside that group's convex hull.
    """
    polygons = [Delaunay(list(zip(fprs, tprs)))
                for fprs, tprs, _ in roc_curves.values()]

    feasible_points = []

    for poly in polygons:
        for p in poly.points:

            if all(poly2.find_simplex(p) != -1 for poly2 in polygons):
                feasible_points.append(p)

    return np.array(feasible_points)
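`_find_feasible_roc` relies on `scipy.spatial.Delaunay.find_simplex`, which returns -1 for points outside the triangulation, i.e., outside the convex hull of the input points. The standalone sketch below applies the same membership test to two made-up ROC curves. One caveat: the function only keeps the groups' own ROC points that lie in every hull. Points where two hull boundaries cross are never candidates, so when `find_separation_thresholds` maximizes over this set, it can miss an optimum located at such a crossing.

```python
import numpy as np
from scipy.spatial import Delaunay

# Two made-up ROC curves (FPR, TPR), each including (0, 0) and (1, 1).
curve_a = np.array([[0, 0], [0.1, 0.5], [0.4, 0.9], [1, 1]])
curve_b = np.array([[0, 0], [0.3, 0.5], [0.6, 0.8], [1, 1]])
hulls = [Delaunay(curve_a), Delaunay(curve_b)]

# Interior points of the ROC square to test against both hulls.
queries = np.array([[0.5, 0.6],    # below both curves, above the diagonal
                    [0.2, 0.6]])   # under curve A but above curve B

# find_simplex returns -1 for points outside a hull.
inside_all = np.all([hull.find_simplex(queries) >= 0 for hull in hulls],
                    axis=0)
print(inside_all)  # [ True False]
```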


def find_separation_thresholds(roc_curves, base_rate, cost_matrix):
    """Compute thresholds that achieve separation and minimize cost.

    Also known as **equalized odds**.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: An empty thresholds dict, the common FPR and TPR
             (under the key ''), and the cost value.
    :rtype: tuple

    """

    feasible_points = _find_feasible_roc(roc_curves)

    cost, (best_fpr, best_tpr) = max((_cost_function(fpr, tpr, base_rate,
                                                     cost_matrix),
                                      (fpr, tpr))
                                     for fpr, tpr in feasible_points)

    return {}, {'': (best_fpr, best_tpr)}, cost


def find_thresholds(roc_curves, proportions, base_rate,
                    base_rates, cost_matrix,
                    with_single=True, with_min_cost=True,
                    with_independence=True, with_fnr=True,
                    with_separation=True):
    """Compute thresholds that achieve various criteria and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
    :rtype: dict

    """

    thresholds = {}

    if with_single:
        thresholds['single'] = find_single_threshold(roc_curves,
                                                     base_rates,
                                                     proportions,
                                                     cost_matrix)

    if with_min_cost:
        thresholds['min_cost'] = find_min_cost_thresholds(roc_curves,
                                                          base_rates,
                                                          cost_matrix)

    if with_independence:
        thresholds['independence'] = find_independence_thresholds(roc_curves,
                                                                  base_rates,
                                                                  proportions,
                                                                  cost_matrix)

    if with_fnr:
        thresholds['fnr'] = find_fnr_thresholds(roc_curves,
                                                base_rates,
                                                proportions,
                                                cost_matrix)

    if with_separation:
        thresholds['separation'] = find_separation_thresholds(roc_curves,
                                                              base_rate,
                                                              cost_matrix)

    return thresholds


def plot_roc_curves_thresholds(roc_curves, thresholds_data,
                               aucs=None,
                               title='ROC Curves by Attribute',
                               ax=None, figsize=None,
                               title_fontsize='large',
                               text_fontsize='medium'):
    """Generate the ROC curves by attribute with thresholds.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param thresholds_data: Thresholds by criterion from the
                            function
                            :func:`~ethically.fairness.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param aucs: Area Under the ROC (AUC) by attribute.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    ax = plot_roc_curves(roc_curves, aucs,
                         title, ax, figsize, title_fontsize, text_fontsize)

    MARKERS = ['o', '^', 'x', '+', 'p']

    for (name, data), marker in zip(thresholds_data.items(), MARKERS):
        label = name.replace('_', ' ').title()
        ax.scatter(*zip(*data[1].values()),
                   marker=marker, color='k', label=label,
                   zorder=float('inf'))

    # Draw the legend on the given axes, not on pyplot's current axes.
    ax.legend()

    return ax


def plot_fpt_tpr(roc_curves,
                 title='FPR-TPR Curves by Attribute',
                 ax=None, figsize=None,
                 title_fontsize='large', text_fontsize='medium'):
    """Plot the FPR and TPR as functions of the threshold, by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    thresholds = _extract_threshold(roc_curves)

    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']

    # Plot on the given axes so that a caller-supplied `ax` is respected.
    for (group, roc), color in zip(roc_curves.items(), colors):
        ax.plot(thresholds, roc[0], '-',
                label='{} - FPR'.format(group), color=color)
        ax.plot(thresholds, roc[1], '--',
                label='{} - TPR'.format(group), color=color)

    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Probability', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(fontsize=text_fontsize)

    return ax


def plot_costs(thresholds_data,
               title='Cost by Threshold',
               ax=None, figsize=None,
               title_fontsize='large', text_fontsize='medium'):
    """Plot the cost of each threshold criterion.

    Based on :func:`skplt.metrics.plot_roc`

    :param thresholds_data: Thresholds by criterion from the
                            function
                            :func:`~ethically.fairness.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    costs = {name.replace('_', ' ').title(): cost
             for name, (_, _, cost, *_) in thresholds_data.items()}

    (pd.Series(costs)
     .sort_values(ascending=False)
     .plot(kind='barh', ax=ax))

    ax.set_xlabel('Cost', fontsize=text_fontsize)
    ax.set_ylabel('Threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
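`plot_costs` reads the third element of each criterion's tuple as its cost, so every result passed to it is expected to carry the cost in that position. Below is a minimal standalone sketch of the same horizontal bar chart. It uses made-up cost values and the non-interactive Agg backend so that it runs without a display.

```python
import matplotlib
matplotlib.use('Agg')  # off-screen backend, no display needed
import matplotlib.pyplot as plt
import pandas as pd

# Made-up results in the (thresholds, fpr_tpr, cost, ...) layout.
thresholds_data = {'single': (0.5, {}, -0.21),
                   'min_cost': ({}, {}, -0.25),
                   'fnr': ({}, {}, -0.23, 0.2)}

# Same label and unpacking logic as plot_costs.
costs = {name.replace('_', ' ').title(): cost
         for name, (_, _, cost, *_) in thresholds_data.items()}

fig, ax = plt.subplots(figsize=(6, 3))
pd.Series(costs).sort_values(ascending=False).plot(kind='barh', ax=ax)
ax.set_xlabel('Cost')
ax.set_ylabel('Threshold')
```

The `*_` in the unpacking is what lets four-element results, such as the FNR criterion's extra FNR value, share the chart with three-element ones.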