Passed: Push to master ( dd7fa7...5946a8 ) by Andreas (01:34, queued 11s)

klib.describe.dist_plot()   D

Complexity
    Conditions: 11

Size
    Total Lines: 139
    Code Lines: 72

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      11
eloc    72
nop     11
dl      0
loc     139
rs      4.6963
c       0
b       0
f       0

How to fix: Long Method, Complexity, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex functions like klib.describe.dist_plot() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields, parameters, or local variables that share the same prefixes or suffixes, such as kde_kws, rug_kws, fill_kws, and font_kws here.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
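As a rough illustration of the prefix/suffix heuristic, the sketch below groups names by the text after their last underscore. `group_by_suffix` is a hypothetical helper written for this example, not part of klib or of the inspection tool; the parameter names are copied from the dist_plot() signature in the listing.

```python
from collections import defaultdict


def group_by_suffix(names, sep='_'):
    """Group names by the text after the last separator; keep groups of two or more."""
    groups = defaultdict(list)
    for name in names:
        if sep in name:
            groups[name.rsplit(sep, 1)[1]].append(name)
    return {suffix: members for suffix, members in groups.items() if len(members) > 1}


# Parameter names copied from the dist_plot() signature.
DIST_PLOT_PARAMS = ['data', 'mean_color', 'figsize', 'fill_range', 'hist', 'bins',
                    'showall', 'kde_kws', 'rug_kws', 'fill_kws', 'font_kws']

candidates = group_by_suffix(DIST_PLOT_PARAMS)
```

For this signature the only group with more than one member is the `kws` suffix, which suggests the four keyword dictionaries as a candidate for extraction.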

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

'''
Functions for descriptive analytics.

:author: Andreas Kanz

'''

# Imports
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns

from .clean import drop_missing
from .utils import _corr_selector
from .utils import _missing_vals
from .utils import _validate_input_bool
from .utils import _validate_input_int
from .utils import _validate_input_range


# Functions

# Categorical Plot
def cat_plot(data, figsize=(10, 14), top=3, bottom=3, bar_color_top='#5ab4ac', bar_color_bottom='#d8b365'):
    '''
    Two-dimensional visualization of the number and frequency of categorical features.

    Parameters
    ----------

    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    figsize: tuple, default (10, 14)
        Use to control the figure size.

    top: int, default 3
        Show the "top" most frequent values in a column.

    bottom: int, default 3
        Show the "bottom" least frequent values in a column.

    bar_color_top: color, default '#5ab4ac'
        Use to control the color of the bars indicating the most common values.

    bar_color_bottom: color, default '#d8b365'
        Use to control the color of the bars indicating the least common values.

    Returns
    -------
    gs: Figure with array of Axes objects.

    '''

    # Validate Inputs
    _validate_input_int(top, 'top')
    _validate_input_int(bottom, 'bottom')
    _validate_input_range(top, 'top', 0, data.shape[1])
    _validate_input_range(bottom, 'bottom', 0, data.shape[1])

    data = pd.DataFrame(data).copy()
    cols = list(data.select_dtypes(exclude=['number']).columns)  # categorical cols
    data = data[cols].applymap(str)

    if len(cols) == 0:
        print('No columns with categorical data were detected.')

    else:
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.2)

        for count, col in enumerate(cols):

            n_unique = data[col].nunique(dropna=False)

            if n_unique <= min(2, top+bottom):
                vals = int(n_unique//2)
                value_counts_top = data[col].value_counts(sort=True)[0:vals]
                value_counts_idx_top = list(map(str, data[col].value_counts()[0:vals].index.tolist()))
                value_counts_bot = data[col].value_counts(sort=True)[-vals:]
                value_counts_idx_bot = list(map(str, data[col].value_counts()[-vals:].index.tolist()))

            else:
                value_counts_top = data[col].value_counts(sort=True)[0:top]
                value_counts_idx_top = list(map(str, data[col].value_counts()[0:top].index.tolist()))
                if bottom == 0:
                    value_counts_bot = []
                    value_counts_idx_bot = []
                else:
                    value_counts_bot = data[col].value_counts(sort=True)[-bottom:]
                    value_counts_idx_bot = list(map(str, data[col].value_counts()[-bottom:].index.tolist()))

            # Encode values for the heatmap: 2 = top, -2 = bottom, 0 = everything else.
            # .loc is used instead of chained indexing so the assignment reliably hits `data`.
            data.loc[data[col].isin(value_counts_idx_top), col] = 2
            data.loc[data[col].isin(value_counts_idx_bot), col] = -2
            data.loc[~((data[col] == 2) | (data[col] == -2)), col] = 0

            # Barcharts
            ax_top = fig.add_subplot(gs[:1, count:count+1])
            ax_top.bar(value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85)
            ax_top.bar(value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85)
            ax_top.set(frame_on=False)
            ax_top.tick_params(axis='x', labelrotation=90)

            # Summary stats
            ax_bottom = fig.add_subplot(gs[1:2, count:count+1])
            ax_bottom.get_yaxis().set_visible(False)
            ax_bottom.get_xaxis().set_visible(False)
            ax_bottom.set(frame_on=False)
            ax_bottom.text(0, 0, f'Unique values: {n_unique}\n\n'
                           f'Top {top} vals: {sum(value_counts_top)} ({sum(value_counts_top)/data.shape[0]*100:.1f}%)\n'
                           f'Bottom {bottom} vals: {sum(value_counts_bot)} ' +
                           f'({sum(value_counts_bot)/data.shape[0]*100:.1f}%)',
                           transform=ax_bottom.transAxes, color='#111111', fontsize=11)

        # Heatmap
        data = data.astype('int')
        ax_hm = fig.add_subplot(gs[2:, :])
        sns.heatmap(data, cmap='BrBG', cbar=False, vmin=-4.25, vmax=4.25, ax=ax_hm)
        ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
        ax_hm.set_yticklabels(ax_hm.get_yticks())
        ax_hm.set_xticklabels(ax_hm.get_xticklabels(),
                              horizontalalignment='center',
                              fontweight='light',
                              fontsize='medium')
        ax_hm.tick_params(length=1, colors='#111111')

        gs.figure.suptitle('Categorical data plot', x=0.47, y=0.925, fontsize=18, color='#111111')

        return gs


# Correlation Matrix
def corr_mat(data, split=None, threshold=0, method='pearson'):
    '''
    Returns a color-encoded correlation matrix.

    Parameters
    ----------

    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

    Returns
    -------
    Pandas Styler object

    '''

    # Validate Inputs
    _validate_input_range(threshold, 'threshold', -1, 1)

    def color_negative_red(val):
        color = '#FF3344' if val < 0 else None
        return 'color: %s' % color

    data = pd.DataFrame(data)
    corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    return corr.style.applymap(color_negative_red).format("{:.2f}", na_rep='-')


# Correlation matrix / heatmap
def corr_plot(data, split=None, threshold=0, target=None, method='pearson', cmap='BrBG', figsize=(12, 10), annot=True,
              dev=False, **kwargs):
    '''
    Two-dimensional visualization of the correlation between feature-columns, excluding NA values.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    split: {None, 'pos', 'neg', 'high', 'low'}, default None
        Type of split to be performed.

        * None: visualize all correlations between the feature-columns.
        * pos: visualize all positive correlations between the feature-columns above the threshold.
        * neg: visualize all negative correlations between the feature-columns below the threshold.
        * high: visualize all correlations between the feature-columns for which abs(corr) > threshold is True.
        * low: visualize all correlations between the feature-columns for which abs(corr) < threshold is True.

    threshold: float, default 0
        Value between 0 <= threshold <= 1

    target: string, list, np.array or pd.Series, default None
        Specify target for correlation. E.g. label column to generate only the correlations between each feature \
        and the label.

    method: {'pearson', 'spearman', 'kendall'}, default 'pearson'
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more expensive but
                   more robust in smaller datasets than 'spearman'.

    cmap: matplotlib colormap name or object, or list of colors, default 'BrBG'
        The mapping from data values to color space.

    figsize: tuple, default (12, 10)
        Use to control the figure size.

    annot: bool, default True
        Use to show or hide annotations.

    dev: bool, default False
        Display figure settings in the plot by setting dev = True. If False, the settings are not displayed.

    **kwargs: optional
        Additional elements to control the visualization of the plot, e.g.:

        * mask: bool, default True
        If set to False the entire correlation matrix, including the upper triangle is shown. Set dev = False in this \
        case to avoid overlap.
        * vmax: float, default is calculated from the given correlation coefficients.
        Value between -1 or vmin <= vmax <= 1, limits the range of the colorbar.
        * vmin: float, default is calculated from the given correlation coefficients.
        Value between -1 <= vmin <= 1 or vmax, limits the range of the colorbar.
        * linewidths: float, default 0.5
        Controls the line-width in between the squares.
        * annot_kws: dict, default {'size' : 10}
        Controls the font size of the annotations. Only available when annot = True.
        * cbar_kws: dict, default {'shrink': .95, 'aspect': 30}
        Controls the size of the colorbar.
        * Many more kwargs are available, e.g. 'alpha' to control blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.

    '''

    # Validate Inputs
    _validate_input_range(threshold, 'threshold', -1, 1)
    _validate_input_bool(annot, 'annot')
    _validate_input_bool(dev, 'dev')

    data = pd.DataFrame(data)

    # Obtain correlations
    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        target_data = []
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)

        elif isinstance(target, (list, pd.Series, np.ndarray)):
            target_data = pd.Series(target)

        # Pass the requested method so the result matches the plot title
        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr.rename_axis(target, axis=1, inplace=True)
        corr = _corr_selector(corr, split=split, threshold=threshold)
        corr = corr.sort_values(corr.columns[0], ascending=False)
        vmax = np.round(np.nanmax(corr)-0.05, 2)
        vmin = np.round(np.nanmin(corr)+0.05, 2)
        mask = False
        square = False

    else:
        corr = corr_mat(data, split=split, threshold=threshold, method=method).data

        # Generate mask for the upper triangle (the np.bool alias was removed from NumPy)
        mask = np.triu(np.ones_like(corr, dtype=bool))
        square = True

        vmax = np.round(np.nanmax(corr.where(~mask))-0.05, 2)
        vmin = np.round(np.nanmin(corr.where(~mask))+0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap
    kwargs = {'mask': mask,
              'cmap': cmap,
              'annot': annot,
              'vmax': vmax,
              'vmin': vmin,
              'linewidths': .5,
              'annot_kws': {'size': 10},
              'cbar_kws': {'shrink': .95, 'aspect': 30},
              **kwargs}

    # Draw heatmap with mask and some default settings
    sns.heatmap(corr,
                center=0,
                square=square,
                fmt='.2f',
                **kwargs
                )

    ax.set_title(f'Feature-correlation ({method})', fontdict={'fontsize': 18})

    # Display settings
    if dev:
        fig.suptitle(f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
                     fontsize=12,
                     color='gray',
                     x=0.35,
                     y=0.85,
                     ha='left')

    return ax


# Distribution plot
def dist_plot(data, mean_color='orange', figsize=(14, 2), fill_range=(0.025, 0.975), hist=False, bins=None,
              showall=False, kde_kws=None, rug_kws=None, fill_kws=None, font_kws=None):
    '''
    Two-dimensional visualization of the distribution of numerical features.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    mean_color: color, default 'orange'
        Color of the vertical line indicating the mean of the data.

    figsize: tuple, default (14, 2)
        Use to control the figure size.

    fill_range: tuple, default (0.025, 0.975)
        Use to set the quantiles for shading. Default spans 95% of the data, which is about two std. deviations \
        above and below the mean.

    hist: bool, default False
        Set to True to display histogram bars in the plot.

    bins: integer, default None
        Specification of the number of hist bins. Requires hist = True

    showall: bool, default False
        Set to True to remove the output limit of 20 plots.

    kde_kws: dict, default None
        Keyword arguments for kdeplot().

    rug_kws: dict, default None
        Keyword arguments for rugplot().

    fill_kws: dict, default None
        Keyword arguments to control the fill.

    font_kws: dict, default None
        Keyword arguments to control the font.

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.

    '''

    # Validate Inputs
    _validate_input_range(fill_range[0], 'fill_range_lower', 0, 1)
    _validate_input_range(fill_range[1], 'fill_range_upper', 0, 1)
    if fill_range[0] >= fill_range[1]:
        raise ValueError('Start value for fill_range must be lower than upper value.')
    _validate_input_bool(hist, 'hist')
    _validate_input_bool(showall, 'showall')

    # Handle dictionary defaults
    kde_kws = {} if kde_kws is None else kde_kws.copy()
    rug_kws = {} if rug_kws is None else rug_kws.copy()
    fill_kws = {} if fill_kws is None else fill_kws.copy()
    font_kws = {} if font_kws is None else font_kws.copy()

    data = drop_missing(pd.DataFrame(data).copy())  # drop empty columns and rows
    cols = list(data.select_dtypes(include=['number']).columns)  # numeric cols
    data = data[cols]

    # Default settings
    kde_kws = {'color': 'k', 'alpha': 0.7, 'linewidth': 1, **kde_kws}
    rug_kws = {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04, **rug_kws}
    fill_kws = {'color': 'brown', 'alpha': 0.1, **fill_kws}
    font_kws = {'color':  '#111111', 'weight': 'normal', 'size': 11, **font_kws}

    if len(cols) == 0:
        print('No columns with numeric data were detected.')
        ax = None

    else:
        # Only warn and truncate when there are more than 20 columns
        if len(cols) > 20 and showall is False:
            print(f'Note: The number of numerical features is very large ({len(cols)}), please consider splitting '
                  'the data. Showing plots for the first 20 numerical features. Override this by setting '
                  'showall=True.')
            cols = cols[:20]

        ax = []
        for col in cols:
            # Drop missing values
            dropped_values = data[col].isna().sum()
            if dropped_values > 0:
                print(f'Dropped {dropped_values} missing values from column {col}.')
                col_data = data[col].dropna(axis=0)
            else:
                col_data = data[col]

            _, ax = plt.subplots(figsize=figsize)
            ax = sns.distplot(col_data, bins=bins, hist=hist, rug=True, kde_kws=kde_kws,
                              rug_kws=rug_kws, hist_kws={'alpha': 0.5, 'histtype': 'step'})

            # Vertical lines and fill
            line = ax.lines[0]
            x = line.get_xydata()[:, 0]
            y = line.get_xydata()[:, 1]
            ax.fill_between(x, y,
                            where=(
                                (x >= np.quantile(col_data, fill_range[0])) &
                                (x <= np.quantile(col_data, fill_range[1]))),
                            label=f'{fill_range[0]*100:.0f}% - {fill_range[1]*100:.0f}%',
                            **fill_kws)

            ax.vlines(x=np.mean(col_data),
                      ymin=0,
                      ymax=np.interp(np.mean(col_data), x, y),
                      ls='dotted', color=mean_color, lw=2, label='mean')
            ax.vlines(x=np.median(col_data),
                      ymin=0,
                      ymax=np.interp(np.median(col_data), x, y),
                      ls=':', color='.3', label='median')
            ax.vlines(x=np.quantile(col_data, 0.25),
                      ymin=0,
                      ymax=np.interp(np.quantile(col_data, 0.25), x, y), ls=':', color='.5', label='25%')
            ax.vlines(x=np.quantile(col_data, 0.75),
                      ymin=0,
                      ymax=np.interp(np.quantile(col_data, 0.75), x, y), ls=':', color='.5', label='75%')

            ax.set_ylim(0,)
            ax.set_xlim(ax.get_xlim()[0]*1.1, ax.get_xlim()[1]*1.1)

            # Annotations and legend
            ax.text(0.01, 0.85, f'Mean: {np.round(np.mean(col_data),2)}',
                    fontdict=font_kws, transform=ax.transAxes)
            ax.text(0.01, 0.7, f'Std. dev: {np.round(scipy.stats.tstd(col_data),2)}',
                    fontdict=font_kws, transform=ax.transAxes)
            ax.text(0.01, 0.55, f'Skew: {np.round(scipy.stats.skew(col_data),2)}',
                    fontdict=font_kws, transform=ax.transAxes)
            ax.text(0.01, 0.4, f'Kurtosis: {np.round(scipy.stats.kurtosis(col_data),2)}',  # Excess Kurtosis
                    fontdict=font_kws, transform=ax.transAxes)
            ax.text(0.01, 0.25, f'Count: {np.round(len(col_data))}',
                    fontdict=font_kws, transform=ax.transAxes)
            ax.legend(loc='upper right')

    return ax


# Missing value plot
def missingval_plot(data, cmap='PuBuGn', figsize=(12, 12), sort=False, spine_color='#EEEEEE'):
    '''
    Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data: 2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the index/column \
    information is used to label the plots.

    cmap: colormap, default 'PuBuGn'
        Any valid colormap can be used. E.g. 'Greys', 'RdPu'. More information can be found in the matplotlib \
        documentation.

    figsize: tuple, default (12, 12)
        Use to control the figure size.

    sort: bool, default False
        Sort columns based on missing values in descending order and drop columns without any missing values.

    spine_color: color, default '#EEEEEE'
        Set to 'None' to hide the spines on all plots or use any valid matplotlib color argument.

    Returns
    -------
    gs: Figure with array of Axes objects.

    '''

    # Validate Inputs
    _validate_input_bool(sort, 'sort')

    data = pd.DataFrame(data)

    if sort:
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted.drop(mv_cols_sorted[mv_cols_sorted.values == 0].keys().tolist()).keys().tolist()
        data = data[final_cols]
        print('Displaying only columns with missing values.')

    # Identify missing values (compute the summary once and reuse it)
    mv = _missing_vals(data)
    mv_cols = mv['mv_cols']  # data.isna().sum(axis=0)
    mv_rows = mv['mv_rows']  # data.isna().sum(axis=1)
    mv_total = mv['mv_total']
    mv_cols_ratio = mv['mv_cols_ratio']  # mv_cols / data.shape[0]
    total_datapoints = data.shape[0]*data.shape[1]

    if mv_total == 0:
        print('No missing values found in the dataset.')
    else:
        # Create figure and axes
        fig = plt.figure(figsize=figsize)
        gs = fig.add_gridspec(nrows=6, ncols=6, left=0.05, wspace=0.05)
        ax1 = fig.add_subplot(gs[:1, :5])
        ax2 = fig.add_subplot(gs[1:, :5])
        ax3 = fig.add_subplot(gs[:1, 5:])
        ax4 = fig.add_subplot(gs[1:, 5:])

        # ax1 - Barplot
        colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
        ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio)*100, 2), color=colors)
        ax1.get_xaxis().set_visible(False)
        ax1.set(frame_on=False, xlim=(-.5, len(mv_cols)-0.5))
        ax1.set_ylim(0, np.max(mv_cols_ratio)*100)
        ax1.grid(linestyle=':', linewidth=1)
        ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
        ax1.tick_params(axis='y', colors='#111111', length=1)

        # annotate values on top of the bars
        for rect, label in zip(ax1.patches, mv_cols):
            height = rect.get_height()
            ax1.text(.1 + rect.get_x() + rect.get_width() / 2, height+0.5, label,
                     ha='center',
                     va='bottom',
                     rotation=90,  # numeric angle; a string such as '90' is not a documented value
                     alpha=0.5,
                     fontsize='small')

        ax1.set_frame_on(True)
        for _, spine in ax1.spines.items():
            spine.set_visible(True)
            spine.set_color(spine_color)
        ax1.spines['top'].set_color(None)

        # ax2 - Heatmap
        sns.heatmap(data.isna(), cbar=False, cmap='binary', ax=ax2)
        ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
        ax2.set_yticklabels(ax2.get_yticks())
        ax2.set_xticklabels(
            ax2.get_xticklabels(),
            horizontalalignment='center',
            fontweight='light',
            fontsize='medium')
        ax2.tick_params(length=1, colors='#111111')
        for _, spine in ax2.spines.items():
            spine.set_visible(True)
            spine.set_color(spine_color)

        # ax3 - Summary
        fontax3 = {'color':  '#111111',
                   'weight': 'normal',
                   'size': 12,
                   }
        ax3.get_xaxis().set_visible(False)
        ax3.get_yaxis().set_visible(False)
        ax3.set(frame_on=False)

        ax3.text(0.1, 0.9, f"Total: {np.round(total_datapoints/1000,1)}K",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.7, f"Missing: {np.round(mv_total/1000,1)}K",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.5, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.3, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)
        ax3.text(0.1, 0.1, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
                 transform=ax3.transAxes,
                 fontdict=fontax3)

        # ax4 - Scatter plot
        ax4.get_yaxis().set_visible(False)
        for _, spine in ax4.spines.items():
            spine.set_color(spine_color)
        ax4.tick_params(axis='x', colors='#111111', length=1)

        ax4.scatter(mv_rows, range(len(mv_rows)), s=mv_rows, c=mv_rows, cmap=cmap, marker=".", vmin=1)
        ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
        ax4.set_xlim(0, max(mv_rows)+0.5)
        ax4.grid(linestyle=':', linewidth=1)

        gs.figure.suptitle('Missing value plot', x=0.45, y=0.94, fontsize=18, color='#111111')

        return gs