_cost_function()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 4
dl 0
loc 14
rs 10
c 0
b 0
f 0
1
"""
2
Post-processing fairness intervention by choosing thresholds.
3
4
There are multiple definitions for choosing the thresholds:
5
6
1. Single threshold for all the sensitive attribute values
7
   that minimizes cost.
8
2. A threshold for each sensitive attribute value
9
   that minimize cost.
10
3. A threshold for each sensitive attribute value
11
   that achieve independence and minimize cost.
12
4. A threshold for each sensitive attribute value
13
   that achieve equal FNR (equal opportunity) and minimize cost.
14
5. A threshold for each sensitive attribute value
15
   that achieve separation (equalized odds) and minimize cost.
16
17
The code is based on `fairmlbook repository <https://github.com/fairmlbook/fairmlbook.github.io>`_.
18
19
References:
20
    - Hardt, M., Price, E., & Srebro, N. (2016).
21
      `Equality of opportunity in supervised learning
22
      <https://arxiv.org/abs/1610.02413>`_
23
      In Advances in neural information processing systems
24
      (pp. 3315-3323).
25
    - `Attacking discrimination with
26
      smarter machine learning by Google
27
      <https://research.google.com/bigpicture/attacking-discrimination-in-ml/>`_.
28
29
"""
30
31
# pylint: disable=no-name-in-module,ungrouped-imports
32
33
from collections import Counter
34
35
import matplotlib.pylab as plt
36
import numpy as np
37
import pandas as pd
38
import seaborn as sns
39
from matplotlib.ticker import AutoMinorLocator
40
from scipy.spatial import Delaunay
41
42
from responsibly.fairness.metrics.score import roc_curve_by_attr
43
from responsibly.fairness.metrics.utils import _groupby_y_x_sens
44
from responsibly.fairness.metrics.visualization import plot_roc_curves
45
46
47
# Tolerance (on the searched value, not on the cost) handed to
# `_ternary_search_float` by the independence and FNR threshold searches.
TRINARY_SEARCH_TOL = 1e-3
48
49
50
def _strictly_increasing(arr):
51
    return (np.diff(arr) >= 0).all()
52
53
54
def _titlify(text):
55
    text = text.replace('_', ' ').title()
56
    if text == 'Fnr':
57
        text = 'FNR'
58
    return text
59
60
61
def _ternary_search_float(f, left, right, tol):
62
    """Trinary search: minimize f(x) over [left, right], to within +/-tol in x.
63
64
    Works assuming f is quasiconvex.
65
    """
66
67
    while right - left > tol:
68
        left_third = (2 * left + right) / 3
69
        right_third = (left + 2 * right) / 3
70
        if f(left_third) < f(right_third):
71
            right = right_third
72
        else:
73
            left = left_third
74
    return (right + left) / 2
75
76
77
def _ternary_search_domain(f, domain):
78
    """Trinary search: minimize f(x) over a domain (sequence).
79
80
    Works assuming f is quasiconvex and domain is ascending sorted.
81
82
    BUGGY, DO NOT USE
83
84
    >>> arr = np.concatenate([np.arange(10, 2, -1), np.arange(2, 20)])
85
    >>> t1 = _ternary_search_domain(lambda t: arr[t], range(len(arr)))
86
    >>> t2 = np.argmin(arr)
87
88
    >>> assert t1 == t2
89
    >>> assert arr[t1] == arr[t2]
90
    """
91
92
    left = 0
93
    right = len(domain) - 1
94
    changed = True
95
96
    while changed and left != right:
97
98
        changed = False
99
100
        left_third = (2 * left + right) // 3
101
        right_third = (left + 2 * right) // 3
102
103
        if f(domain[left_third]) < f(domain[right_third]):
104
            right = right_third - 1
105
            changed = True
106
        else:
107
            left = left_third + 1
108
            changed = True
109
110
    return domain[(left + right) // 2]
111
112
113
def _cost_function(fpr, tpr, base_rate, cost_matrix):
114
    """Compute the cost of given (fpr, tpr).
115
116
    [[tn, fp], [fn, tp]]
117
    """
118
119
    fp = fpr * (1 - base_rate)
120
    tn = (1 - base_rate) - fp
121
    tp = tpr * base_rate
122
    fn = base_rate - tp
123
124
    conf_matrix = np.array([tn, fp, fn, tp])
125
126
    return (conf_matrix * np.array(cost_matrix).ravel()).sum()
127
128
129
def _extract_threshold(roc_curves):
130
    return next(iter(roc_curves.values()))[2]
131
132
133
def _first_index_above(arr, value):
    """Find the smallest index i for which arr[i] > value.

    If no such item exists, return len(arr).

    :param arr: Sequence accepted by :func:`_strictly_increasing`.
    :param value: Value that the returned item must exceed.
    :return: Index of the first item above ``value``, or ``len(arr)``.
    """

    assert _strictly_increasing(arr), (
        'arr should be stricktly increasing.')

    is_above = np.array(arr) > value
    # The trailing sentinel makes argmax land on len(arr)
    # when no item of arr is above value.
    with_sentinel = np.concatenate([is_above, [1]])
    return np.argmax(with_sentinel, axis=0)
145
146
147
def _calc_acceptance_rate(fpr, tpr, base_rate):
148
    return (fpr * (1 - base_rate)
149
            + tpr * base_rate)
150
151
152
def find_single_threshold(roc_curves, base_rates, proportions,
                          cost_matrix):
    """Compute single threshold that minimizes cost.

    Each threshold index is scored with the negated, proportion-weighted
    sum of the per-group values of :func:`_cost_function`; the index
    with the smallest score is selected for all groups.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Threshold, FPR and TPR by attribute and cost value
             (the negated weighted sum at the selected index).
    :rtype: tuple

    """

    def negated_total_cost(position):
        weighted_costs = (
            _cost_function(roc[0][position], roc[1][position],
                           base_rates[group], cost_matrix)
            * proportions[group]
            for group, roc in roc_curves.items())
        return -sum(weighted_costs)

    thresholds = _extract_threshold(roc_curves)

    scores = [negated_total_cost(position)
              for position, _ in enumerate(thresholds)]
    best = np.argmin(scores)

    fpr_tpr = {}
    for group, roc in roc_curves.items():
        fpr_tpr[group] = (roc[0][best], roc[1][best])

    return thresholds[best], fpr_tpr, scores[best]
198
199
200
def find_min_cost_thresholds(roc_curves, base_rates, proportions, cost_matrix):
    """Compute thresholds by attribute values that minimize cost.

    Every group is optimized on its own: each threshold index is scored
    with the negated value of :func:`_cost_function` and the index with
    the smallest score is kept for that group.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute and cost value
             (proportion-weighted sum of the selected group scores).
    :rtype: tuple

    """

    thresholds = _extract_threshold(roc_curves)
    positions = range(len(thresholds))

    cutoffs, fpr_tpr = {}, {}
    cost = 0

    for group, roc in roc_curves.items():
        fprs, tprs = roc[0], roc[1]

        scores = [-_cost_function(fprs[position], tprs[position],
                                  base_rates[group], cost_matrix)
                  for position in positions]
        best = np.argmin(scores)

        cutoffs[group] = thresholds[best]
        fpr_tpr[group] = (fprs[best], tprs[best])

        cost += scores[best] * proportions[group]

    return cutoffs, fpr_tpr, cost
243
244
245
def get_acceptance_rate_indices(roc_curves, base_rates,
                                acceptance_rate_value):
    """Locate, per group, the first ROC index exceeding an acceptance rate.

    The acceptance rates along each group's curve are derived from its
    FPRs, TPRs and base rate; the index is then looked up with
    :func:`_first_index_above`, which is documented to return the
    length of the array when no item is above the value.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param acceptance_rate_value: Acceptance rate to exceed.
    :return: Index by attribute.
    :rtype: dict
    """
    # The acceptance rates do not depend on the requested value,
    # so they could be calculated once outside this function.
    return {
        group: _first_index_above(
            _calc_acceptance_rate(fpr=roc[0],
                                  tpr=roc[1],
                                  base_rate=base_rates[group]),
            acceptance_rate_value)
        for group, roc in roc_curves.items()
    }
261
262
263 View Code Duplication
def find_independence_thresholds(roc_curves, base_rates, proportions,
                                 cost_matrix):
    """Compute thresholds that achieve independence and minimize cost.

    A common acceptance rate is searched over [0, 1] with
    :func:`_ternary_search_float`; for every candidate rate each group
    uses the ROC index given by :func:`get_acceptance_rate_indices`.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value
             and the selected acceptance rate.
    :rtype: tuple

    """

    def negated_total_cost(acceptance_rate_value):
        # TODO: move demo here + multiple cost
        #       + refactor - use threshold to calculate
        #         acceptance_rate_value
        indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                              acceptance_rate_value)

        weighted_costs = (
            _cost_function(roc[0][indices[group]],
                           roc[1][indices[group]],
                           base_rates[group],
                           cost_matrix)
            * proportions[group]
            for group, roc in roc_curves.items())

        return -sum(weighted_costs)

    acceptance_rate_min_cost = _ternary_search_float(negated_total_cost,
                                                     0, 1, TRINARY_SEARCH_TOL)

    cost = negated_total_cost(acceptance_rate_min_cost)

    threshold_indices = get_acceptance_rate_indices(roc_curves, base_rates,
                                                    acceptance_rate_min_cost)
    thresholds = _extract_threshold(roc_curves)

    cutoffs, fpr_tpr = {}, {}
    for group, roc in roc_curves.items():
        position = threshold_indices[group]
        cutoffs[group] = thresholds[position]
        fpr_tpr[group] = (roc[0][position], roc[1][position])

    return cutoffs, fpr_tpr, cost, acceptance_rate_min_cost
326
327
328
def get_fnr_indices(roc_curves, fnr_value):
    """Locate, per group, the ROC index matching a false negative rate.

    The FNR is converted to a TPR (``1 - fnr_value``); the index just
    before the first TPR above it is taken, clipped at zero.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param fnr_value: Requested false negative rate.
    :return: Index by attribute.
    :rtype: dict
    """
    tpr_value = 1 - fnr_value

    return {group: max(0, _first_index_above(roc[1], tpr_value) - 1)
            for group, roc in roc_curves.items()}
341
342
343 View Code Duplication
def find_fnr_thresholds(roc_curves, base_rates, proportions,
                        cost_matrix):
    """Compute thresholds that achieve equal FNRs and minimize cost.

    Also known as **equal opportunity**.

    A common FNR is searched over [0, 1] with
    :func:`_ternary_search_float`; for every candidate value each group
    uses the ROC index given by :func:`get_fnr_indices`.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds, FPR and TPR by attribute, cost value
             and the selected FNR.
    :rtype: tuple

    """

    def negated_total_cost(fnr_value):
        # TODO: move demo here + multiple cost
        indices = get_fnr_indices(roc_curves, fnr_value)

        weighted_costs = (
            _cost_function(roc[0][indices[group]],
                           roc[1][indices[group]],
                           base_rates[group],
                           cost_matrix)
            * proportions[group]
            for group, roc in roc_curves.items())

        return -sum(weighted_costs)

    fnr_value_min_cost = _ternary_search_float(negated_total_cost,
                                               0, 1,
                                               TRINARY_SEARCH_TOL)

    threshold_indices = get_fnr_indices(roc_curves, fnr_value_min_cost)

    cost = negated_total_cost(fnr_value_min_cost)

    thresholds = _extract_threshold(roc_curves)

    cutoffs, fpr_tpr = {}, {}
    for group, roc in roc_curves.items():
        position = threshold_indices[group]
        fpr_tpr[group] = (roc[0][position], roc[1][position])
        cutoffs[group] = thresholds[position]

    return cutoffs, fpr_tpr, cost, fnr_value_min_cost
404
405
406
def _find_feasible_roc(roc_curves):
407
    polygons = [Delaunay(list(zip(fprs, tprs)))
408
                for group, (fprs, tprs, _) in roc_curves.items()]
409
410
    feasible_points = []
411
412
    for poly in polygons:
413
        for p in poly.points:
414
415
            if all(poly2.find_simplex(p) != -1 for poly2 in polygons):
416
                feasible_points.append(p)
417
418
    return np.array(feasible_points)
419
420
421
def find_separation_thresholds(roc_curves, base_rate, cost_matrix):
    """Compute thresholds that achieve separation and minimize cost.

    Also known as **equalized odds**.

    The candidate (fpr, tpr) points come from
    :func:`_find_feasible_roc`; the point with the largest value of
    :func:`_cost_function` is selected and that value is returned
    negated.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence
    :return: Thresholds (always an empty dict here), the selected
             FPR and TPR under the key ``''`` and cost value.
    :rtype: tuple

    """

    feasible_points = _find_feasible_roc(roc_curves)

    # Pairs are compared as tuples, so ties on the value are
    # broken by the (fpr, tpr) point itself.
    scored_points = ((_cost_function(fpr, tpr, base_rate, cost_matrix),
                      (fpr, tpr))
                     for fpr, tpr in feasible_points)
    best_value, (best_fpr, best_tpr) = max(scored_points)

    return {}, {'': (best_fpr, best_tpr)}, -best_value
447
448
449
def find_thresholds(roc_curves, proportions, base_rate,
                    base_rates, cost_matrix,
                    with_single=True, with_min_cost=True,
                    with_independence=True, with_fnr=True,
                    with_separation=True):
    """Compute thresholds that achieve various criteria and minimize cost.

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param proportions: Proportion of each attribute value.
    :type proportions: dict
    :param base_rate: Overall base rate.
    :type base_rate: float
    :param base_rates: Base rate by attribute.
    :type base_rates: dict
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
    :rtype: dict

    """

    by_group_args = (roc_curves, base_rates, proportions, cost_matrix)
    overall_args = (roc_curves, base_rate, cost_matrix)

    # (enabled, result key, finder, positional arguments), in output order.
    strategies = (
        (with_single, 'single', find_single_threshold, by_group_args),
        (with_min_cost, 'min_cost', find_min_cost_thresholds, by_group_args),
        (with_independence, 'independence', find_independence_thresholds,
         by_group_args),
        (with_fnr, 'fnr', find_fnr_thresholds, by_group_args),
        (with_separation, 'separation', find_separation_thresholds,
         overall_args),
    )

    return {key: finder(*args)
            for enabled, key, finder, args in strategies
            if enabled}
518
519
520
def find_thresholds_by_attr(y_true, y_score, x_sens,
                            cost_matrix,
                            with_single=True, with_min_cost=True,
                            with_independence=True, with_fnr=True,
                            with_separation=True,
                            pos_label=None, sample_weight=None,
                            drop_intermediate=False):
    """
    Compute thresholds that achieve various criteria and minimize cost.

    The ROC curves, group proportions and base rates are derived from
    the raw data and handed to :func:`find_thresholds`.

    :param y_true: Binary ground truth (correct) target values.
    :param y_score: Estimated target score as returned by a classifier.
    :param x_sens: Sensitive attribute values corresponded to each
                   estimated target.
    :param cost_matrix: Cost matrix by [[tn, fp], [fn, tp]].
    :type cost_matrix: sequence

    :param pos_label: Label considered as positive and others
                      are considered negative. When ``None``, the
                      label ``1`` is used for the base rates.
    :param sample_weight: Sample weights.
    :param drop_intermediate: Whether to drop some suboptimal
                              thresholds which would not appear on
                              a plotted ROC curve.
                              This is useful in order to create
                              lighter ROC curves.

    :param with_single: Compute single threshold.
    :type with_single: bool
    :param with_min_cost: Compute minimum cost thresholds.
    :type with_min_cost: bool
    :param with_independence: Compute independence thresholds.
    :type with_independence: bool
    :param with_fnr: Compute FNR thresholds.
    :type with_fnr: bool
    :param with_separation: Compute separation thresholds.
    :type with_separation: bool

    :return: Dictionary of threshold criteria,
             and for each criterion:
             thresholds, FPR and TPR by attribute and cost value.
    :rtype: dict
    """
    # pylint: disable=too-many-locals

    # The ROC helper receives the caller's pos_label untouched
    # (possibly None); the default of 1 applies to the base rates only.
    roc_curves = roc_curve_by_attr(y_true, y_score, x_sens,
                                   pos_label, sample_weight,
                                   drop_intermediate)

    proportions = {value: count / len(x_sens)
                   for value, count in Counter(x_sens).items()}

    if pos_label is None:
        pos_label = 1

    # Compare element-wise even when y_true is a plain list: a list
    # compared with a scalar is a single False, which silently made
    # the overall base rate 0.0.
    base_rate = np.mean(np.asarray(y_true) == pos_label)
    grouped = _groupby_y_x_sens(y_true, y_score, x_sens)

    base_rates = {x_sens_value: np.mean(group['y_true'] == pos_label)
                  for x_sens_value, group in grouped}

    thresholds_data = find_thresholds(roc_curves,
                                      proportions,
                                      base_rate,
                                      base_rates,
                                      cost_matrix,
                                      with_single, with_min_cost,
                                      with_independence, with_fnr,
                                      with_separation)

    return thresholds_data
590
591
592
def plot_roc_curves_thresholds(roc_curves, thresholds_data,
                               aucs=None,
                               title='ROC Curves by Attribute',
                               ax=None, figsize=None,
                               title_fontsize='large',
                               text_fontsize='medium'):
    """Generate the ROC curves by attribute with thresholds.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~responsibly.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param aucs: Area Under the ROC (AUC) by attribute.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    ax = plot_roc_curves(roc_curves, aucs,
                         title, ax, figsize, title_fontsize, text_fontsize)

    # One marker per strategy; zip stops at the shorter sequence, so
    # strategies beyond the fifth are not drawn.
    MARKERS = ['o', '^', 'x', '+', 'p']

    for (name, data), marker in zip(thresholds_data.items(), MARKERS):
        label = _titlify(name)
        # data[1] maps each attribute to its (fpr, tpr) point.
        ax.scatter(*zip(*data[1].values()),
                   marker=marker, color='k', label=label,
                   zorder=float('inf'))

    # Build the legend on the axes that were drawn on; the pyplot-level
    # call targets the current axes, which need not be `ax`.
    ax.legend()

    return ax
642
643
644
def plot_fpt_tpr(roc_curves,
                 title='FPR-TPR Curves by Attribute',
                 ax=None, figsize=None,
                 title_fontsize='large', text_fontsize='medium'):
    """Generate FPR and TPR curves by thresholds and by attribute.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute.
    :type roc_curves: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    thresholds = _extract_threshold(roc_curves)

    prop_cycle = plt.rcParams['axes.prop_cycle']
    colors = prop_cycle.by_key()['color']

    # Draw on `ax` itself: the pyplot-level functions target the
    # current axes, which need not be the axes passed by the caller.
    # One color per group; solid line for FPR, dashed line for TPR.
    for (group, roc), color in zip(roc_curves.items(), colors):
        ax.plot(thresholds, roc[0], '-',
                label='{} - FPR'.format(group), color=color)
        ax.plot(thresholds, roc[1], '--',
                label='{} - TPR'.format(group), color=color)

    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Probability', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(fontsize=text_fontsize)

    return ax
695
696
697
def plot_costs(thresholds_data,
               title='Cost by Threshold Strategy',
               ax=None, figsize=None,
               title_fontsize='large', text_fontsize='medium'):
    """Plot cost by threshold definition and by attribute.

    The cost is the third item of each strategy's entry; strategies
    are drawn as horizontal bars sorted by descending cost.

    Based on :func:`skplt.metrics.plot_roc`

    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~responsibly.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`
    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    costs = {}
    for strategy, (_, _, cost, *_) in thresholds_data.items():
        costs[_titlify(strategy)] = cost

    cost_by_strategy = pd.Series(costs).sort_values(ascending=False)
    cost_by_strategy.plot(kind='barh', ax=ax)

    ax.set_xlabel('Cost', fontsize=text_fontsize)
    ax.set_ylabel('Threshold', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
742
743
744
def plot_thresholds(thresholds_data,
                    markersize=7,
                    title='Thresholds by Strategy and Attribute',
                    xlim=None,
                    ax=None, figsize=None,
                    title_fontsize='large', text_fontsize='medium'):
    """Plot thresholds by strategy and by attribute.

    The ``'separation'`` strategy is left out of the plot.

    Based on :func:`skplt.metrics.plot_roc`

    :param thresholds_data: Thresholds by attribute from the
                            function
                            :func:`~responsibly.interventions
                            .threshold.find_thresholds`.
    :type thresholds_data: dict
    :param int markersize: Marker size.
    :param str title: Title of the generated plot.
    :param tuple xlim: Set the data limits for the x-axis.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`
    """

    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    # TODO: refactor!
    # One column per strategy; the first item of each entry holds
    # that strategy's thresholds.
    df = pd.DataFrame({_titlify(key): thresholds
                       for key, (thresholds, *_) in thresholds_data.items()
                       if key != 'separation'})
    melted_df = pd.melt(df, var_name='Strategy', value_name='Threshold')
    # melt stacks the columns one after the other, so the index
    # (attribute values) repeats once per strategy column.
    melted_df['Attribute'] = list(df.index) * len(df.columns)

    sns.stripplot(y='Strategy', x='Threshold', hue='Attribute', data=melted_df,
                  jitter=False, dodge=True, size=markersize, ax=ax)

    # Configure the axes directly: no figure handle exists when the
    # caller supplies `ax`, which used to raise on the `fig` lookup.
    minor_locator = AutoMinorLocator(2)
    ax.yaxis.set_minor_locator(minor_locator)
    ax.grid(which='minor')

    if xlim is not None:
        ax.set_xlim(*xlim)

    ax.set_xlabel('Threshold', fontsize=text_fontsize)
    ax.set_ylabel('Strategy', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)

    return ax
803