GitHub Access Token became invalid

It seems that the GitHub access token used for retrieving details about this repository from GitHub has become invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Passed
Push — master ( 4f98db...92a4ba )
by Andreas
01:40
created

klib.describe.corr_mat()   A

Complexity

Conditions 2

Size

Total Lines 38
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 4
dl 0
loc 38
rs 9.95
c 0
b 0
f 0
1
'''
2
Functions for descriptive analytics.
3
4
:author: Andreas Kanz
5
6
'''
7
8
# Imports
9
import matplotlib.pyplot as plt
10
import matplotlib.ticker as ticker
11
import numpy as np
12
import pandas as pd
13
import scipy
14
import seaborn as sns
15
16
from .utils import _corr_selector
17
from .utils import _missing_vals
18
from .utils import _validate_input_0_1
19
from .utils import _validate_input_bool
20
21
22
# Functions
23
24
# Correlation Matrix
25
def corr_mat(data, split=None, threshold=0, method='pearson'):
    '''
    Compute the correlation matrix of the feature-columns and return it as a styled table.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column
        information is used to label the result.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed. Forwarded, together with the threshold, to _corr_selector.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

    Returns
    -------
    Pandas Styler object wrapping the selected correlations. Values are formatted with two decimals, missing entries
    are rendered as '-' and negative values are given a red font color.
    '''

    # Validate Inputs
    _validate_input_0_1(threshold, 'threshold')

    def _negative_in_red(value):
        # Cell-wise CSS: negative coefficients get the red hex code, all other cells get 'color: None'.
        if value < 0:
            shade = '#FF3344'
        else:
            shade = None
        return 'color: %s' % shade

    frame = pd.DataFrame(data)
    selected = _corr_selector(frame.corr(method=method), split=split, threshold=threshold)

    styler = selected.style.applymap(_negative_in_red)
    return styler.format("{:.2f}", na_rep='-')
63
64
65
# Correlation matrix / heatmap
66
def corr_plot(data, split=None, threshold=0, target=None, method='pearson', cmap='BrBG', figsize=(12, 10), annot=True,
              dev=False, **kwargs):
    '''
    Two-dimensional visualization of the correlation between feature-columns, excluding NA values.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column
        information is used to label the plots.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed.

        * None: visualize all correlations between the feature-columns.
        * pos: visualize all positive correlations between the feature-columns above the threshold.
        * neg: visualize all negative correlations between the feature-columns below the threshold.
        * high: visualize all correlations between the feature-columns for which abs(corr) > threshold is True.
        * low: visualize all correlations between the feature-columns for which abs(corr) < threshold is True.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    target: string, list, np.array or pd.Series, default None
        Specify target for correlation. E.g. label column to generate only the correlations between each feature
        and the label. A string is interpreted as a column name of data; that column is removed from the features
        before correlating. Any other accepted type is converted with pd.Series(target).

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

        The method is applied both to the full correlation matrix and to the correlation with a target.

    cmap: matplotlib colormap name or object, or list of colors, default 'BrBG'
        The mapping from data values to color space.

    figsize: tuple, default (12, 10)
        Use to control the figure size.

    annot: bool, default True
        Use to show or hide annotations.

    dev: bool, default False
        Display figure settings in the plot by setting dev = True. If False, the settings are not displayed.

    **kwargs: optional
        Additional elements to control the visualization of the plot, e.g.:

        * mask: bool or array. Without a target the upper triangle is masked by default, with a target nothing is
        masked. If set to False the entire correlation matrix, including the upper triangle is shown. Set
        dev = False in this case to avoid overlap.
        * vmax: float, default is calculated from the given correlation coefficients.
        Value between -1 or vmin <= vmax <= 1, limits the range of the colorbar.
        * vmin: float, default is calculated from the given correlation coefficients.
        Value between -1 <= vmin <= 1 or vmax, limits the range of the colorbar.
        * linewidths: float, default 0.5
        Controls the line-width in between the squares.
        * annot_kws: dict, default {'size' : 10}
        Controls the font size of the annotations. Only available when annot = True.
        * cbar_kws: dict, default {'shrink': .95, 'aspect': 30}
        Controls the size of the colorbar.
        * Many more kwargs are available, i.e. 'alpha' to control blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above). They are forwarded to
        seaborn's heatmap and take precedence over the defaults listed here.

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    '''

    # Validate Inputs (helpers are defined in .utils)
    _validate_input_0_1(threshold, 'threshold')
    _validate_input_bool(annot, 'annot')
    _validate_input_bool(dev, 'dev')

    data = pd.DataFrame(data)

    # Obtain correlations
    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        # Correlation of every feature with a single target -> one-column result, nothing to mask.
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)

        # Pass the method on so the computed coefficients match the method named in the title.
        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = _corr_selector(corr, split=split, threshold=threshold)

        mask = False
        square = False

        vmax = np.round(np.nanmax(corr)-0.05, 2)
        vmin = np.round(np.nanmin(corr)+0.05, 2)

    else:
        # Full feature-by-feature matrix; corr_mat returns a Styler, .data is the underlying DataFrame.
        corr = corr_mat(data, split=split, threshold=threshold, method=method).data

        # Generate mask for the upper triangle. The builtin bool is used because the np.bool alias
        # no longer exists in current NumPy releases.
        mask = np.triu(np.ones_like(corr, dtype=bool))
        square = True

        # Color range is derived from the visible (unmasked) coefficients only.
        vmax = np.round(np.nanmax(corr.where(~mask))-0.05, 2)
        vmin = np.round(np.nanmin(corr.where(~mask))+0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap; user supplied kwargs come last and therefore override the defaults
    kwargs = {'mask': mask,
              'cmap': cmap,
              'annot': annot,
              'vmax': vmax,
              'vmin': vmin,
              'linewidths': .5,
              'annot_kws': {'size': 10},
              'cbar_kws': {'shrink': .95, 'aspect': 30},
              'ax': ax,
              **kwargs}

    # Draw heatmap with mask and some default settings
    sns.heatmap(corr,
                center=0,
                square=square,
                fmt='.2f',
                **kwargs
                )

    ax.set_title(f'Feature-correlation ({method})', fontdict={'fontsize': 18})

    # Display settings. vmax/vmin are read from kwargs so that user overrides are reported correctly.
    if dev:
        fig.suptitle(f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {kwargs['vmax']} \n\
                - vmin: {kwargs['vmin']} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
                     fontsize=12,
                     color='gray',
                     x=0.35,
                     y=0.85,
                     ha='left')

    return ax
212
213
214
# Distribution plot
215
def dist_plot(data, mean_color='orange', figsize=(14, 2), fill_range=(0.025, 0.975), hist=False, showall=False,
              kde_kws=None, rug_kws=None, fill_kws=None, font_kws=None):
    '''
    One-dimensional visualization of the distribution of each numeric feature-column in a dataset. One figure is
    created per column, showing a density curve with rug, a shaded quantile range, markers for mean, median and
    quartiles, and a short statistical summary.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column
        information is used to label the plots.

    mean_color: any valid color, default 'orange'
        Color of the vertical line indicating the mean of the data.

    figsize: tuple, default (14, 2)
        Use to control the figure size.

    fill_range: tuple, default (0.025, 0.975)
        Use to set the quantiles for shading. Default spans 95% of the data, which is about two std. deviations
        above and below the mean.

    hist: bool, default False
        Set to True to display histogram bars in the plot.

    showall: bool, default False
        Set to True to remove the output limit of 20 plots.

    kde_kws: dict, optional
        Keyword arguments for kdeplot(). Merged over the defaults {'color': 'k', 'alpha': 0.6, 'linewidth': 1}.

    rug_kws: dict, optional
        Keyword arguments for rugplot(). Merged over the defaults
        {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04}.

    fill_kws: dict, optional
        Keyword arguments to control the fill. Merged over the defaults {'color': 'brown', 'alpha': 0.1}.

    font_kws: dict, optional
        Keyword arguments to control the font. Merged over the defaults
        {'color': '#111111', 'weight': 'normal', 'size': 11}.

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking. If no numeric columns are found an empty
        list is returned instead.
    '''

    # Validate Inputs (helper is defined in .utils)
    _validate_input_bool(hist, 'hist')
    _validate_input_bool(showall, 'showall')

    data = pd.DataFrame(data).copy()
    cols = list(data.select_dtypes(include=['number']).columns)  # numeric cols

    if not cols:
        print('No columns with numeric data were detected.')
    elif len(cols) > 20 and not showall:
        # Only warn when columns are actually dropped, i.e. for more than 20 numeric features.
        print(
            f'Note: The number of features is very large ({len(cols)}), please consider splitting the data.\
            Showing plots for the first 20 numerical features. Override this by setting showall=True.')
        cols = cols[:20]

    # Default settings; None is accepted for every *_kws argument and treated like an empty dict
    kde_kws = {'color': 'k', 'alpha': 0.6, 'linewidth': 1, **(kde_kws or {})}
    rug_kws = {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04, **(rug_kws or {})}
    fill_kws = {'color': 'brown', 'alpha': 0.1, **(fill_kws or {})}
    font_kws = {'color':  '#111111', 'weight': 'normal', 'size': 11, **(font_kws or {})}

    ax = []
    for col in cols:
        # Drop missing values once so the plotted data and all statistics refer to the same observations
        # (np.median / np.quantile would otherwise return NaN as soon as a single value is missing).
        values = data[col].dropna()

        mean = np.mean(values)
        median = np.median(values)
        q25, q75 = np.quantile(values, [0.25, 0.75])
        fill_low, fill_high = np.quantile(values, [fill_range[0], fill_range[1]])

        _, ax = plt.subplots(figsize=figsize)
        ax = sns.distplot(values, hist=hist, rug=True, kde_kws=kde_kws, rug_kws=rug_kws, ax=ax)

        # Vertical lines and fill; the first line on the axes is taken as the density curve
        curve = ax.lines[0].get_xydata()
        x = curve[:, 0]
        y = curve[:, 1]
        ax.fill_between(x, y,
                        where=((x >= fill_low) & (x <= fill_high)),
                        label=f'{fill_range[0]*100:.0f}% - {fill_range[1]*100:.0f}%',
                        **fill_kws)

        # Each marker ends on the density curve, hence the interpolation of the curve height
        ax.vlines(x=mean,
                  ymin=0,
                  ymax=np.interp(mean, x, y),
                  ls='dotted', color=mean_color, lw=2, label='mean')
        ax.vlines(x=median,
                  ymin=0,
                  ymax=np.interp(median, x, y),
                  ls=':', color='.4', label='median')
        ax.vlines(x=q25,
                  ymin=0,
                  ymax=np.interp(q25, x, y), ls=':', color='.6', label='25%')
        ax.vlines(x=q75,
                  ymin=0,
                  ymax=np.interp(q75, x, y), ls=':', color='.6', label='75%')

        ax.set_ylim(0,)

        # Annotations and legend
        ax.text(0.01, 0.85, f'Mean: {np.round(mean, 2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.7, f'Std. dev: {np.round(scipy.stats.tstd(values), 2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.55, f'Skew: {np.round(scipy.stats.skew(values), 2)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.4, f'Kurtosis: {np.round(scipy.stats.kurtosis(values), 2)}',  # Excess Kurtosis
                fontdict=font_kws, transform=ax.transAxes)
        # Count refers to the non-missing observations the statistics above are based on
        ax.text(0.01, 0.25, f'Count: {len(values)}',
                fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc='upper right')

    return ax
327
328
329
# Missing value plot
330
def missingval_plot(data, cmap='PuBuGn', figsize=(12, 12), sort=False, spine_color='#EEEEEE'):
    '''
    Two-dimensional visualization of the missing values in a dataset. The figure combines a bar chart of missing
    values per column, a heatmap of the missing entries, a text summary and a scatter plot of missing values per row.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column
        information is used to label the plots.

    cmap: colormap, default 'PuBuGn'
        Any valid colormap can be used. E.g. 'Greys', 'RdPu'. More information can be found in the matplotlib
        documentation.

    figsize: tuple, default (12, 12)
        Use to control the figure size.

    sort: bool, default False
        Sort columns based on missing values in descending order and drop columns without any missing values

    spine_color: color-code, default '#EEEEEE'
        Set to 'None' to hide the spines on all plots or use any valid matplotlib color argument.

    Returns
    -------
    gs: matplotlib GridSpec
        The GridSpec the four axes are placed on; the figure is available as gs.figure. If the dataset contains no
        missing values, a message is printed, nothing is drawn and None is returned.
    '''

    data = pd.DataFrame(data)

    if sort:
        # Keep only columns that contain missing values, ordered from most to fewest missing entries
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted[mv_cols_sorted > 0].index.tolist()
        data = data[final_cols]
        print('Displaying only columns with missing values.')

    # Identify missing values with a single call to the helper from .utils (not visible here).
    # The formulas in the trailing comments are carried over from the original inline notes.
    mv = _missing_vals(data)
    mv_cols = mv['mv_cols']  # data.isna().sum(axis=0)
    mv_rows = mv['mv_rows']  # data.isna().sum(axis=1)
    mv_total = mv['mv_total']
    mv_cols_ratio = mv['mv_cols_ratio']  # mv_cols / data.shape[0]
    total_datapoints = data.shape[0]*data.shape[1]

    if mv_total == 0:
        print('No missing values found in the dataset.')
        return None

    # Create figure and axes: narrow top row / right column for the summaries, large cell for the heatmap
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.05, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio)*100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-.5, len(mv_cols)-0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio)*100)
    ax1.grid(linestyle=':', linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
    ax1.tick_params(axis='y', colors='#111111', length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        # rotation is given as a number; matplotlib only accepts the strings 'vertical' / 'horizontal'
        ax1.text(.1 + rect.get_x() + rect.get_width() / 2, height+0.5, label,
                 ha='center',
                 va='bottom',
                 rotation=90,
                 alpha=0.5,
                 fontsize='small')

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines['top'].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap='binary', ax=ax2)
    ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))  # keep every 5th tick, rounded to tens
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment='center',
        fontweight='light',
        fontsize='medium')
    ax2.tick_params(length=1, colors='#111111')
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {'color':  '#111111',
               'weight': 'normal',
               'size': 12,
               }
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    ax3.text(0.1, 0.9, f"Total: {np.round(total_datapoints/1000,1)}K",
             transform=ax3.transAxes,
             fontdict=fontax3)
    ax3.text(0.1, 0.7, f"Missing: {np.round(mv_total/1000,1)}K",
             transform=ax3.transAxes,
             fontdict=fontax3)
    ax3.text(0.1, 0.5, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
             transform=ax3.transAxes,
             fontdict=fontax3)
    ax3.text(0.1, 0.3, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
             transform=ax3.transAxes,
             fontdict=fontax3)
    ax3.text(0.1, 0.1, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
             transform=ax3.transAxes,
             fontdict=fontax3)

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis='x', colors='#111111', length=1)

    ax4.scatter(mv_rows, range(len(mv_rows)), s=mv_rows, c=mv_rows, cmap=cmap, marker=".", vmin=1)
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows)+0.5)
    ax4.grid(linestyle=':', linewidth=1)

    gs.figure.suptitle('Missing value plot', x=0.45, y=0.94, fontsize=18, color='#111111')

    return gs
462