Passed
Pull Request — master (#14)
by Shlomi
02:03
created

ethically.fairness.metrics.visualization   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 124
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 57
dl 0
loc 124
rs 10
c 0
b 0
f 0
wmc 8

4 Functions

Rating   Name   Duplication   Size   Complexity  
A _groupby() 0 5 2
A plot_roc_curves() 0 54 4
A plot_roc_by_attr() 0 33 1
A distplot_by() 0 14 1
1
from collections import defaultdict
2
3
import seaborn as sns
4
from matplotlib import pylab as plt
5
6
from ethically.fairness.metrics.score import (
7
    roc_auc_score_by_attr, roc_curve_by_attr,
8
)
9
10
11
def _groupby(x, by):
12
    d = defaultdict(list)
13
    for key, val in zip(by, x):
14
        d[key].append(val)
15
    return d
16
17
18
def distplot_by(a, by, bins=None, hist=True, kde=True, rug=False,
                fit=None, hist_kws=None, kde_kws=None, rug_kws=None,
                fit_kws=None, vertical=False, norm_hist=False,
                ax=None):
    """Draw one :func:`seaborn.distplot` per group of `a`, on shared axes.

    The observations in `a` are split according to the parallel keys
    in `by`; every group is plotted with the group key as its label
    and a legend is added to the axes that were drawn on.

    All parameters other than `a`, `by` and `ax` are forwarded
    unchanged to :func:`seaborn.distplot` for every group.

    :param a: Observed values.
    :param by: Group key for each value of `a`.
    :param ax: The axes upon which to plot.
               If `None`, the current axes (``plt.gca()``) are used.
    :return: The objects returned by :func:`seaborn.distplot`,
             one per group.
    :rtype: list
    """
    # Resolve the target axes once so that every group and the legend
    # end up on the same axes. ``plt.legend()`` would put the legend on
    # the *current* axes, which is not necessarily the `ax` passed in.
    if ax is None:
        ax = plt.gca()

    axes = [sns.distplot(a_group,
                         bins=bins, hist=hist, kde=kde, rug=rug,
                         fit=fit, hist_kws=hist_kws, kde_kws=kde_kws,
                         rug_kws=rug_kws, fit_kws=fit_kws,
                         vertical=vertical, norm_hist=norm_hist,
                         ax=ax, label=group)
            for group, a_group in _groupby(a, by).items()]
    ax.legend()
    return axes
32
33
34
# Source: https://github.com/reiinakano/scikit-plot/blob/master/scikitplot/metrics.py#L332
35
def plot_roc_curves(roc_curves, aucs=None,
                    title='ROC Curves by Attribute',
                    ax=None, figsize=None,
                    title_fontsize='large', text_fontsize='medium'):
    """Generate the ROC curves by attribute from (fpr, tpr, thresholds).

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC)
                       by attribute. Each value is indexed as
                       ``[0]`` for the x values (FPR) and ``[1]``
                       for the y values (TPR).
    :type roc_curves: dict
    :param aucs: Area Under the ROC (AUC) by attribute. When given,
                 it must contain every key of `roc_curves`.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`
    :raises KeyError: If `aucs` is given but lacks a group that is
                      present in `roc_curves`.

    """

    if ax is None:
        # Only the axes are needed; the figure stays reachable
        # through ``ax.figure``.
        _, ax = plt.subplots(1, 1, figsize=figsize)

    ax.set_title(title, fontsize=title_fontsize)

    for x_sens_value, curve in roc_curves.items():

        label = 'ROC curve of group {0}'.format(x_sens_value)
        if aucs is not None:
            label += ' (area = {:0.2f})'.format(aucs[x_sens_value])

        # Index instead of unpacking so that both (fpr, tpr) and
        # (fpr, tpr, thresholds) sequences are accepted.
        ax.plot(curve[0],
                curve[1],
                lw=2,
                label=label)

    # Diagonal reference line from (0, 0) to (1, 1).
    ax.plot([0, 1], [0, 1], 'k--', lw=2)
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=text_fontsize)
    ax.set_ylabel('True Positive Rate', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='lower right', fontsize=text_fontsize)

    return ax
89
90
91
def plot_roc_by_attr(y_true, y_score, x_sens,
                     title='ROC Curves by Attribute',
                     ax=None, figsize=None,
                     title_fontsize='large', text_fontsize='medium'):
    """Generate the ROC curves by attribute from targets and scores.

    Computes the per-attribute ROC curves and AUC scores with
    :func:`roc_curve_by_attr` and :func:`roc_auc_score_by_attr`,
    then delegates the drawing to :func:`plot_roc_curves`.

    Based on :func:`skplt.metrics.plot_roc`

    :param y_true: Binary ground truth (correct) target values.
    :param y_score: Estimated target score as returned by a classifier.
    :param x_sens: Sensitive attribute values corresponding to each
                   estimated target.
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    roc_curves = roc_curve_by_attr(y_true, y_score, x_sens)
    aucs = roc_auc_score_by_attr(y_true, y_score, x_sens)

    # Forward the layout options by keyword so that this wrapper does
    # not depend on the parameter order of `plot_roc_curves`.
    return plot_roc_curves(roc_curves, aucs,
                           title=title, ax=ax, figsize=figsize,
                           title_fontsize=title_fontsize,
                           text_fontsize=text_fontsize)
124