Passed
Pull Request — master (#14)
by Shlomi
02:03
created

distplot_by()   A

Complexity

Conditions 1

Size

Total Lines 14
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 13
nop 14
dl 0
loc 14
rs 9.75
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
from collections import defaultdict
2
3
import seaborn as sns
4
from matplotlib import pylab as plt
5
6
from ethically.fairness.metrics.score import (
7
    roc_auc_score_by_attr, roc_curve_by_attr,
8
)
9
10
11
def _groupby(x, by):
12
    d = defaultdict(list)
13
    for key, val in zip(by, x):
14
        d[key].append(val)
15
    return d
16
17
18
def distplot_by(a, by, bins=None, hist=True, kde=True, rug=False,
                fit=None, hist_kws=None, kde_kws=None, rug_kws=None,
                fit_kws=None, vertical=False, norm_hist=False,
                ax=None):
    """Plot the distribution of observations, one layer per group.

    The observations in ``a`` are split according to the matching values
    in ``by``; every group is drawn with :func:`seaborn.distplot` on the
    same axes and labeled with its group value. A legend is then added
    to that axes.

    :param a: Observed values.
    :param by: Group value of each observation, aligned with ``a``.
    :param bins: Forwarded unchanged to :func:`seaborn.distplot`.
    :param hist: Forwarded unchanged to :func:`seaborn.distplot`.
    :param kde: Forwarded unchanged to :func:`seaborn.distplot`.
    :param rug: Forwarded unchanged to :func:`seaborn.distplot`.
    :param fit: Forwarded unchanged to :func:`seaborn.distplot`.
    :param hist_kws: Forwarded unchanged to :func:`seaborn.distplot`.
    :param kde_kws: Forwarded unchanged to :func:`seaborn.distplot`.
    :param rug_kws: Forwarded unchanged to :func:`seaborn.distplot`.
    :param fit_kws: Forwarded unchanged to :func:`seaborn.distplot`.
    :param vertical: Forwarded unchanged to :func:`seaborn.distplot`.
    :param norm_hist: Forwarded unchanged to :func:`seaborn.distplot`.
    :param ax: The axes upon which to plot.
               If `None`, the current pyplot axes is used.
    :return: The values returned by :func:`seaborn.distplot`,
             one per group.
    :rtype: list
    """
    # Resolve the target axes once so that every group *and* the legend
    # share it. A bare ``plt.legend()`` annotates the current pyplot
    # axes, which is not necessarily the ``ax`` supplied by the caller.
    if ax is None:
        ax = plt.gca()

    axes = [sns.distplot(a_group,
                         bins=bins, hist=hist, kde=kde, rug=rug,
                         fit=fit, hist_kws=hist_kws, kde_kws=kde_kws,
                         rug_kws=rug_kws, fit_kws=fit_kws,
                         vertical=vertical, norm_hist=norm_hist,
                         ax=ax, label=group)
            for group, a_group in _groupby(a, by).items()]
    ax.legend()
    return axes
32
33
34
# Source: https://github.com/reiinakano/scikit-plot/blob/master/scikitplot/metrics.py#L332
35
def plot_roc_curves(roc_curves, aucs=None,
                    title='ROC Curves by Attribute',
                    ax=None, figsize=None,
                    title_fontsize='large', text_fontsize='medium'):
    """Draw one ROC curve per attribute value from precomputed curves.

    Based on :func:`skplt.metrics.plot_roc`

    :param roc_curves: Receiver operating characteristic (ROC) curves
                       keyed by attribute value; the first two items
                       of each entry are plotted as the x and y data
                       (``(fpr, tpr, thresholds)``).
    :type roc_curves: dict
    :param aucs: Area Under the ROC (AUC) keyed by attribute value.
                 When given, the area is appended to each legend label.
    :type aucs: dict
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6). Only used when a new set of
                          axes is created.
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    if ax is None:
        # Only the axes is needed; the figure stays reachable through it.
        ax = plt.subplots(1, 1, figsize=figsize)[1]

    ax.set_title(title, fontsize=title_fontsize)

    for group in roc_curves:
        curve = roc_curves[group]

        label = 'ROC curve of group {0}'.format(group)
        if aucs is not None:
            label = label + ' (area = {:0.2f})'.format(aucs[group])

        ax.plot(curve[0], curve[1], lw=2, label=label)

    # Diagonal reference line.
    ax.plot([0, 1], [0, 1], 'k--', lw=2)

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])

    ax.set_xlabel('False Positive Rate', fontsize=text_fontsize)
    ax.set_ylabel('True Positive Rate', fontsize=text_fontsize)
    ax.tick_params(labelsize=text_fontsize)
    ax.legend(loc='lower right', fontsize=text_fontsize)

    return ax
89
90
91
def plot_roc_by_attr(y_true, y_score, x_sens,
                     title='ROC Curves by Attribute',
                     ax=None, figsize=None,
                     title_fontsize='large', text_fontsize='medium'):
    """Draw one ROC curve per attribute value from targets and scores.

    The per-attribute curves and AUC values are computed with
    :func:`roc_curve_by_attr` and :func:`roc_auc_score_by_attr`,
    then rendered by :func:`plot_roc_curves`.

    Based on :func:`skplt.metrics.plot_roc`

    :param y_true: Binary ground truth (correct) target values.
    :param y_score: Estimated target score as returned by a classifier.
    :param x_sens: Sensitive attribute values corresponding to each
                   estimated target.
    :param str title: Title of the generated plot.
    :param ax: The axes upon which to plot the curve.
               If `None`, the plot is drawn on a new set of axes.
    :param tuple figsize: Tuple denoting figure size of the plot
                          e.g. (6, 6).
    :param title_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :param text_fontsize: Matplotlib-style fontsizes.
                          Use e.g. 'small', 'medium', 'large'
                          or integer-values.
    :return: The axes on which the plot was drawn.
    :rtype: :class:`matplotlib.axes.Axes`

    """

    curves_by_group = roc_curve_by_attr(y_true, y_score, x_sens)
    auc_by_group = roc_auc_score_by_attr(y_true, y_score, x_sens)

    return plot_roc_curves(
        curves_by_group,
        auc_by_group,
        title=title,
        ax=ax,
        figsize=figsize,
        title_fontsize=title_fontsize,
        text_fontsize=text_fontsize,
    )
124