Passed
Push — master ( e1a880...f71d41 )
by Andreas
01:26
created

klib.describe.dist_plot()   C

Complexity

Conditions 9

Size

Total Lines 123
Code Lines 62

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 62
nop 11
dl 0
loc 123
rs 5.9103
c 0
b 0
f 0

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
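As a minimal sketch of that advice, the example below turns each commented step of a function into a small, named helper. All function names are hypothetical, and the standard-library `statistics` module stands in for the NumPy/SciPy calls used in the inspected code; the labels only echo the annotations that `dist_plot` draws.

```python
import statistics


# Before: one function whose comments mark three separate steps.
def describe_before(values):
    # mean label
    mean = f"Mean: {round(statistics.mean(values), 2)}"
    # standard deviation label
    std = f"Std. dev: {round(statistics.stdev(values), 2)}"
    # count label
    count = f"Count: {len(values)}"
    return [mean, std, count]


# After: each comment became the name of a small helper (hypothetical names).
def mean_label(values):
    return f"Mean: {round(statistics.mean(values), 2)}"


def std_dev_label(values):
    return f"Std. dev: {round(statistics.stdev(values), 2)}"


def count_label(values):
    return f"Count: {len(values)}"


def describe_after(values):
    # The orchestrating function now reads like a table of contents.
    return [mean_label(values), std_dev_label(values), count_label(values)]
```

Both versions return the same labels; the second can be tested and reused one helper at a time.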

Commonly applied refactorings include:

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent once you need more or different data.

There are several approaches to avoid long parameter lists:
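One such approach is a parameter object, which groups related arguments into a single value. The sketch below is an assumption about how that could look for a function shaped like `dist_plot`: `DistPlotOptions` and `plan_dist_plot` are hypothetical names that do not exist in klib, and only the field names and defaults are taken from the inspected signature.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class DistPlotOptions:
    """Hypothetical parameter object; fields mirror dist_plot's arguments."""
    mean_color: str = "orange"
    figsize: Tuple[float, float] = (14, 2)
    fill_range: Tuple[float, float] = (0.025, 0.975)
    hist: bool = False
    bins: Optional[int] = None
    showall: bool = False


def plan_dist_plot(columns, options=DistPlotOptions(), max_plots=20):
    """Return the columns that would be plotted under the given options."""
    columns = list(columns)
    if options.showall:
        return columns
    return columns[:max_plots]


defaults = DistPlotOptions()
# Derive a variant without repeating every other setting.
with_hist = replace(defaults, hist=True, bins=30)
```

Because the dataclass is frozen, it is safe to use an instance as a default argument, and callers override single settings with `dataclasses.replace` instead of passing a long positional list.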

'''
Functions for descriptive analytics.

:author: Andreas Kanz

'''

# Imports
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import scipy.stats  # import the submodule explicitly; scipy.stats is used below
import seaborn as sns

from .utils import _corr_selector
from .utils import _missing_vals
from .utils import _validate_input_0_1
from .utils import _validate_input_bool


# Functions

# Correlation Matrix
def corr_mat(data, split=None, threshold=0, method='pearson'):
    '''
    Parameters
    ----------

    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

    Returns
    -------
    returns a Pandas Styler object
    '''

    # Validate Inputs
    _validate_input_0_1(threshold, 'threshold')

    def color_negative_red(val):
        color = '#FF3344' if val < 0 else None
        return 'color: %s' % color

    data = pd.DataFrame(data)
    corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    return corr.style.applymap(color_negative_red).format("{:.2f}", na_rep='-')


# Correlation matrix / heatmap
def corr_plot(data, split=None, threshold=0, target=None, method='pearson', cmap='BrBG', figsize=(12, 10), annot=True,
              dev=False, **kwargs):
    '''
    Two-dimensional visualization of the correlation between feature-columns, excluding NA values.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed.

        * None: visualize all correlations between the feature-columns.
        * pos: visualize all positive correlations between the feature-columns above the threshold.
        * neg: visualize all negative correlations between the feature-columns below the threshold.
        * high: visualize all correlations between the feature-columns for which abs(corr) > threshold is True.
        * low: visualize all correlations between the feature-columns for which abs(corr) < threshold is True.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    target: string, list, np.array or pd.Series, default None
        Specify target for correlation. E.g. label column to generate only the correlations between each feature \
        and the label.

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

    cmap: matplotlib colormap name or object, or list of colors, default 'BrBG'
        The mapping from data values to color space.

    figsize: tuple, default (12, 10)
        Use to control the figure size.

    annot: bool, default True
        Use to show or hide annotations.

    dev: bool, default False
        Display figure settings in the plot by setting dev = True. If False, the settings are not displayed.

    **kwargs: optional
        Additional elements to control the visualization of the plot, e.g.:

        * mask: bool, default True
        If set to False the entire correlation matrix, including the upper triangle is shown. Set dev = False in this \
        case to avoid overlap.
        * vmax: float, default is calculated from the given correlation coefficients.
        Value between -1 or vmin <= vmax <= 1, limits the range of the colorbar.
        * vmin: float, default is calculated from the given correlation coefficients.
        Value between -1 <= vmin <= 1 or vmax, limits the range of the colorbar.
        * linewidths: float, default 0.5
        Controls the line-width in between the squares.
        * annot_kws: dict, default {'size' : 10}
        Controls the font size of the annotations. Only available when annot = True.
        * cbar_kws: dict, default {'shrink': .95, 'aspect': 30}
        Controls the size of the colorbar.
        * Many more kwargs are available, e.g. 'alpha' to control blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    '''

    # Validate Inputs
    _validate_input_0_1(threshold, 'threshold')
    _validate_input_bool(annot, 'annot')
    _validate_input_bool(dev, 'dev')

    data = pd.DataFrame(data)

    # Obtain correlations
    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        target_data = []
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)

        elif isinstance(target, (list, pd.Series, np.ndarray)):
            target_data = pd.Series(target)

        corr = pd.DataFrame(data.corrwith(target_data))
        corr.rename_axis(target, axis=1, inplace=True)
        corr = _corr_selector(corr, split=split, threshold=threshold)
        corr = corr.sort_values(corr.columns[0], ascending=False)
        vmax = np.round(np.nanmax(corr)-0.05, 2)
        vmin = np.round(np.nanmin(corr)+0.05, 2)
        mask = False
        square = False

    else:
        corr = corr_mat(data, split=split, threshold=threshold, method=method).data

        # Generate mask for the upper triangle (builtin bool; the np.bool alias was removed in NumPy 1.24)
        mask = np.triu(np.ones_like(corr, dtype=bool))
        square = True

        vmax = np.round(np.nanmax(corr.where(~mask))-0.05, 2)
        vmin = np.round(np.nanmin(corr.where(~mask))+0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap
    kwargs = {'mask': mask,
              'cmap': cmap,
              'annot': annot,
              'vmax': vmax,
              'vmin': vmin,
              'linewidths': .5,
              'annot_kws': {'size': 10},
              'cbar_kws': {'shrink': .95, 'aspect': 30},
              **kwargs}

    # Draw heatmap with mask and some default settings
    sns.heatmap(corr,
                center=0,
                square=square,
                fmt='.2f',
                **kwargs
                )

    ax.set_title(f'Feature-correlation ({method})', fontdict={'fontsize': 18})

    # Display settings
    if dev:
        fig.suptitle(f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
                     fontsize=12,
                     color='gray',
                     x=0.35,
                     y=0.85,
                     ha='left')

    return ax


# Distribution plot
def dist_plot(data, mean_color='orange', figsize=(14, 2), fill_range=(0.025, 0.975), hist=False, bins=None,
              showall=False, kde_kws=None, rug_kws=None, fill_kws=None, font_kws=None):
    '''
    Two-dimensional visualization of the distribution of numerical features.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    mean_color: any valid color, default 'orange'
        Color of the vertical line indicating the mean of the data.

    figsize: tuple, default (14, 2)
        Use to control the figure size.

    fill_range: tuple, default (0.025, 0.975)
        Use to set the quantiles for shading. Default spans 95% of the data, which is about two std. deviations \
        above and below the mean.

    hist: bool, default False
        Set to True to display histogram bars in the plot.

    bins: integer, default None
        Specification of the number of hist bins. Requires hist = True.

    showall: bool, default False
        Set to True to remove the output limit of 20 plots.

    kde_kws: dict, default None
        Keyword arguments for kdeplot().

    rug_kws: dict, default None
        Keyword arguments for rugplot().

    fill_kws: dict, default None
        Keyword arguments to control the fill.

    font_kws: dict, default None
        Keyword arguments to control the font.

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    '''

    # Validate Inputs
    _validate_input_bool(hist, 'hist')
    _validate_input_bool(showall, 'showall')

    # Handle dictionary defaults
    kde_kws = {} if kde_kws is None else kde_kws.copy()
    rug_kws = {} if rug_kws is None else rug_kws.copy()
    fill_kws = {} if fill_kws is None else fill_kws.copy()
    font_kws = {} if font_kws is None else font_kws.copy()

    data = pd.DataFrame(data).copy()
    cols = list(data.select_dtypes(include=['number']).columns)  # numeric cols

    if len(cols) == 0:
        print('No columns with numeric data were detected.')
    elif len(cols) > 20 and showall is False:  # only truncate when more than 20 plots would be drawn
        print(
            f'Note: The number of numerical features is very large ({len(cols)}), please consider splitting the data. '
            'Showing plots for the first 20 numerical features. Override this by setting showall=True.')
        cols = cols[:20]

    # Default settings
    kde_kws = {'color': 'k', 'alpha': 0.7, 'linewidth': 1, **kde_kws}
    rug_kws = {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04, **rug_kws}
    fill_kws = {'color': 'brown', 'alpha': 0.1, **fill_kws}
    font_kws = {'color':  '#111111', 'weight': 'normal', 'size': 11, **font_kws}

    ax = []
    for col in cols:
        _, ax = plt.subplots(figsize=figsize)
        ax = sns.distplot(data[col], bins=bins, hist=hist, rug=True, kde_kws=kde_kws,
                          rug_kws=rug_kws, hist_kws={'alpha': 0.5, 'histtype': 'step'})

        # Vertical lines and fill
        line = ax.lines[0]
        x = line.get_xydata()[:, 0]
        y = line.get_xydata()[:, 1]
        ax.fill_between(x, y,
                        where=(
                            (x >= np.quantile(data[col], fill_range[0])) &
                            (x <= np.quantile(data[col], fill_range[1]))),
                        label=f'{fill_range[0]*100:.0f}% - {fill_range[1]*100:.0f}%',
                        **fill_kws)

        ax.vlines(x=np.mean(data[col]),
                  ymin=0,
                  ymax=np.interp(np.mean(data[col]), x, y),
                  ls='dotted', color=mean_color, lw=2, label='mean')
        ax.vlines(x=np.median(data[col]),
                  ymin=0,
                  ymax=np.interp(np.median(data[col]), x, y),
                  ls=':', color='.3', label='median')
        ax.vlines(x=np.quantile(data[col], 0.25),
                  ymin=0,
                  ymax=np.interp(np.quantile(data[col], 0.25), x, y), ls=':', color='.5', label='25%')
        ax.vlines(x=np.quantile(data[col], 0.75),
                  ymin=0,
                  ymax=np.interp(np.quantile(data[col], 0.75), x, y), ls=':', color='.5', label='75%')

        ax.set_ylim(0,)
        ax.set_xlim(ax.get_xlim()[0]*1.1, ax.get_xlim()[1]*1.1)

        # Annotations and legend
        ax.text(0.01, 0.85, f'Mean: {np.round(np.mean(data[col]),2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.7, f'Std. dev: {np.round(scipy.stats.tstd(data[col]),2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.55, f'Skew: {np.round(scipy.stats.skew(data[col]),2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.4, f'Kurtosis: {np.round(scipy.stats.kurtosis(data[col]),2)}',  # Excess Kurtosis
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.25, f'Count: {np.round(len(data[col]))}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc='upper right')

    return ax


# Missing value plot
def missingval_plot(data, cmap='PuBuGn', figsize=(12, 12), sort=False, spine_color='#EEEEEE'):
    '''
    Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    cmap: colormap, default 'PuBuGn'
        Any valid colormap can be used. E.g. 'Greys', 'RdPu'. More information can be found in the matplotlib \
        documentation.

    figsize: tuple, default (12, 12)
        Use to control the figure size.

    sort: bool, default False
        Sort columns based on missing values in descending order and drop columns without any missing values.

    spine_color: color-code, default '#EEEEEE'
        Set to 'None' to hide the spines on all plots or use any valid matplotlib color argument.

    Returns
    -------
    figure
    '''

    data = pd.DataFrame(data)

    if sort:
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted.drop(mv_cols_sorted[mv_cols_sorted.values == 0].keys().tolist()).keys().tolist()
        data = data[final_cols]
        print('Displaying only columns with missing values.')

    # Identify missing values
    mv_cols = _missing_vals(data)['mv_cols']  # data.isna().sum(axis=0)
    mv_rows = _missing_vals(data)['mv_rows']  # data.isna().sum(axis=1)
    mv_total = _missing_vals(data)['mv_total']
    mv_cols_ratio = _missing_vals(data)['mv_cols_ratio']  # mv_cols / data.shape[0]
    total_datapoints = data.shape[0]*data.shape[1]

    if mv_total == 0:
        print('No missing values found in the dataset.')
    else:
        # Create figure and axes
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(nrows=6, ncols=6, left=0.05, wspace=0.05)
        ax1 = fig.add_subplot(gs[:1, :5])
        ax2 = fig.add_subplot(gs[1:, :5])
        ax3 = fig.add_subplot(gs[:1, 5:])
        ax4 = fig.add_subplot(gs[1:, 5:])

        # ax1 - Barplot
        colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
        ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio)*100, 2), color=colors)
        ax1.get_xaxis().set_visible(False)
        ax1.set(frame_on=False, xlim=(-.5, len(mv_cols)-0.5))
        ax1.set_ylim(0, np.max(mv_cols_ratio)*100)
        ax1.grid(linestyle=':', linewidth=1)
        ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
        ax1.tick_params(axis='y', colors='#111111', length=1)

        # annotate values on top of the bars
        for rect, label in zip(ax1.patches, mv_cols):
            height = rect.get_height()
            ax1.text(.1 + rect.get_x() + rect.get_width() / 2, height+0.5, label,
                     ha='center',
                     va='bottom',
                     rotation=90,  # numeric angle in degrees instead of the string '90'
                     alpha=0.5,
                     fontsize='small')

        ax1.set_frame_on(True)
        for _, spine in ax1.spines.items():
            spine.set_visible(True)
            spine.set_color(spine_color)
        ax1.spines['top'].set_color(None)

        # ax2 - Heatmap
        sns.heatmap(data.isna(), cbar=False, cmap='binary', ax=ax2)
        ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
        ax2.set_yticklabels(ax2.get_yticks())
        ax2.set_xticklabels(
            ax2.get_xticklabels(),
            horizontalalignment='center',
            fontweight='light',
            fontsize='medium')
        ax2.tick_params(length=1, colors='#111111')
        for _, spine in ax2.spines.items():
            spine.set_visible(True)
            spine.set_color(spine_color)

        # ax3 - Summary
        fontax3 = {'color':  '#111111',
                   'weight': 'normal',
                   'size': 12,
                   }
        ax3.get_xaxis().set_visible(False)
        ax3.get_yaxis().set_visible(False)
        ax3.set(frame_on=False)

        ax3.text(0.1, 0.9, f"Total: {np.round(total_datapoints/1000,1)}K",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.7, f"Missing: {np.round(mv_total/1000,1)}K",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.5, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.3, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.1, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)

        # ax4 - Scatter plot
        ax4.get_yaxis().set_visible(False)
        for _, spine in ax4.spines.items():
            spine.set_color(spine_color)
        ax4.tick_params(axis='x', colors='#111111', length=1)

        ax4.scatter(mv_rows, range(len(mv_rows)), s=mv_rows, c=mv_rows, cmap=cmap, marker=".", vmin=1)
        ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
        ax4.set_xlim(0, max(mv_rows)+0.5)
        ax4.grid(linestyle=':', linewidth=1)

        gs.figure.suptitle('Missing value plot', x=0.45, y=0.94, fontsize=18, color='#111111')

        return gs