GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Passed
Branch master (e3dc91)
by Andreas
01:34 queued 34s
created

klib.describe.corr_plot()   B

Complexity

Conditions 3

Size

Total Lines 136
Code Lines 43

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 43
nop 10
dl 0
loc 136
rs 8.8478
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
"""
2
Functions for descriptive analytics.
3
4
:author: Andreas Kanz
5
6
"""
7
8
# Imports
9
import matplotlib.pyplot as plt
10
import matplotlib.ticker as ticker
11
import numpy as np
12
import pandas as pd
13
import scipy
14
import seaborn as sns
15
16
from typing import Any, Dict, Optional, Tuple, Union
17
from klib.utils import (
18
    _corr_selector,
19
    _missing_vals,
20
    _validate_input_bool,
21
    _validate_input_int,
22
    _validate_input_smaller,
23
    _validate_input_range,
24
)
25
26
27
__all__ = ["cat_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
28
29
30
# Functions
31
32
# Categorical Plot
33
def cat_plot(
    data: pd.DataFrame,
    figsize: Tuple = (16, 16),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
    cmap: str = "BrBG",
):
    """ Two-dimensional visualization of the number and frequency of categorical features.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the
        index/column information is used to label the plots
    figsize : Tuple, optional
        Use to control the figure size, by default (16, 16)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by default "#d8b365"
    cmap : str, optional
        The mapping from data values to color space, by default "BrBG"

    Returns
    -------
    GridSpec or None
        gs: Figure with array of Axes objects. None is returned (after printing a notice) when the data
        contains no non-numeric columns, since there is nothing to plot.
    """

    # Coerce first so that ``data.shape`` is available to the validators for any 2D input
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])

    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if not cols:
        # A grid with zero columns cannot be created, so stop here instead of failing later
        print("No columns with categorical data were detected.")
        return None

    # The cells are overwritten with the numeric codes 2 / -2 / 0 below, which the
    # "category" and "string" dtypes do not accept
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.2)

    for count, col in enumerate(cols):

        n_unique = data[col].nunique(dropna=False)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Too few distinct values to fill both groups: split them evenly
        if n_unique < top + bottom:
            lim_top = lim_bot = int(n_unique // 2)

        value_counts_top = value_counts.iloc[0:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_bot = value_counts.iloc[-lim_bot:]
        value_counts_idx_bot = value_counts_bot.index.tolist()

        # Handled independently: with an ``elif`` the bottom group was left untouched when both
        # limits were 0, and the ``-0:`` slice above then selected every value
        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Encode each cell for the heatmap: 2 = top value, -2 = bottom value, 0 = everything else
        data.loc[data[col].isin(value_counts_idx_top), col] = 2
        data.loc[data[col].isin(value_counts_idx_bot), col] = -2
        data.loc[((data[col] != 2) & (data[col] != -2)), col] = 0

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85)
        ax_top.bar(value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85)
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {top} vals: {sum(value_counts_top)} ({sum(value_counts_top)/data.shape[0]*100:.1f}%)\n"
            f"Bot {bottom} vals: {sum(value_counts_bot)} "
            + f"({sum(value_counts_bot)/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    data = data.astype("int")
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cmap, cbar=False, vmin=-4.25, vmax=4.25, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(), horizontalalignment="center", fontweight="light", fontsize="medium"
    )
    ax_hm.tick_params(length=1, colors="#111111")

    gs.figure.suptitle("Categorical data plot", x=0.47, y=0.925, fontsize=18, color="#111111")

    return gs
149
150
151
# Correlation Matrix
152
def corr_mat(
    data: pd.DataFrame,
    split: Optional[str] = None,  # Optional[Literal['pos', 'neg', 'high', 'low']] = None,
    threshold: float = 0,
    target: Optional[Union[pd.DataFrame, pd.Series, np.ndarray, str]] = None,
    method: str = "pearson",  # Literal['pearson', 'spearman', 'kendall'] = "pearson",
    colored: bool = True,
) -> Union[pd.DataFrame, Any]:
    """ Returns a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the
        index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed, by default None
        {None, 'pos', 'neg', 'high', 'low'}
    threshold : float, optional
        Value between 0 <= threshold <= 1, by default 0
    target : Optional[Union[pd.Series, np.ndarray, list, str]], optional
        Specify target for correlation. E.g. label column to generate only the correlations between each
        feature and the label, by default None. A string is interpreted as the name of a column in
        ``data``; a list, Series or array is used as the target values directly.
    method : str, optional
        method: {'pearson', 'spearman', 'kendall'}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more
            expensive but more robust in smaller datasets than 'spearman'
        The method applies both to the full matrix and to the correlation with ``target``.
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in red, by default True

    Returns
    -------
    Union[pd.DataFrame, pd.Styler]
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val):
        # An empty string leaves the cell unstyled; "color: None" would not be valid CSS
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            # Target is one of the columns: separate it from the features
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            # Target values were passed directly; the Series name (possibly None) labels the result
            target_data = pd.Series(target)
            target = target_data.name

        # Forward ``method`` so the requested coefficient is used here as well
        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        return corr.style.applymap(color_negative_red).format("{:.2f}", na_rep="-")

    return corr
224
225
226
# Correlation matrix / heatmap
227
def corr_plot(
    data: pd.DataFrame,
    split: Optional[str] = None,
    threshold: float = 0,
    target: Optional[Union[pd.Series, str]] = None,
    method: str = "pearson",
    cmap: str = "BrBG",
    figsize: Tuple = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,
):
    """ Two-dimensional visualization of the correlation between feature-columns, excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the
        index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, 'pos', 'neg', 'high', 'low'}, by default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns above the threshold
            * neg: visualize all negative correlations between the feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 <= threshold <= 1, by default 0
    target : Optional[Union[pd.Series, str]], optional
        Specify target for correlation. E.g. label column to generate only the correlations between each
        feature and the label, by default None
    method : str, optional
        method: {'pearson', 'spearman', 'kendall'}, by default "pearson"
            * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally more
            expensive but more robust in smaller datasets than 'spearman'.

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or object, or list of colors,
        by default "BrBG"
    figsize : Tuple, optional
        Use to control the figure size, by default (12, 10). Ignored when an ``ax`` is supplied via kwargs.
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the settings are not
        displayed, by default False

    Keyword Arguments : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the upper triangle is shown. Set
                dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the colorbar.
            * vmin: float, default is calculated from the given correlation coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the colorbar.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {'size' : 10}
                Controls the font size of the annotations. Only available when annot = True.
            * cbar_kws: dict, default {'shrink': .95, 'aspect': 30}
                Controls the size of the colorbar.
            * center: float, default 0
                Value at which the colormap is centered.
            * fmt: str, default ".2f"
                Format string used for the annotations.
            * ax: matplotlib Axes, default is a newly created Axes
                Axes to draw the heatmap on.
            * Many more kwargs are available, i.e. 'alpha' to control blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(data, split=split, threshold=threshold, target=target, method=method, colored=False)

    # The builtin ``bool`` replaces the ``np.bool`` alias, which current NumPy no longer provides
    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        # Full matrix is symmetric: hide the upper triangle (including the diagonal)
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Default color range: slightly inside the extremes of the visible coefficients
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    # Draw on the caller's Axes if one was supplied, so that the title and the returned
    # object refer to the Axes that actually holds the heatmap
    ax = kwargs.get("ax")
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Specify kwargs for the heatmap; user-supplied kwargs take precedence over the defaults.
    # ``center`` and ``fmt`` live here as well so that overriding them does not raise a
    # duplicate-keyword TypeError.
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        "center": 0,
        "fmt": ".2f",
        **kwargs,
        "ax": ax,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        # Report the effective values, i.e. after user overrides have been applied
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {kwargs['vmax']} \n\
                - vmin: {kwargs['vmin']} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
363
364
365
# Distribution plot
366
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    figsize: Tuple = (14, 2),
    fill_range: Tuple = (0.025, 0.975),
    hist: bool = False,
    bins: int = 10,
    showall: bool = False,
    kde_kws: Dict[str, Any] = None,
    rug_kws: Dict[str, Any] = None,
    fill_kws: Dict[str, Any] = None,
    font_kws: Dict[str, Any] = None,
):
    """ Two-dimensional visualization of the distribution of numerical features.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the
        index/column information is used to label the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default "orange"
    figsize : Tuple, optional
        Controls the figure size, by default (14, 2)
    fill_range : Tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is about two std. deviations
        above and below the mean, by default (0.025, 0.975)
    hist : bool, optional
        Set to True to display histogram bars in the plot, by default False
    bins : int, optional
        Specification of the number of hist bins. Requires hist = True, by default 10
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : Dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {'alpha': 0.7, 'linewidth': 1.5}
    rug_kws : Dict[str, Any], optional
        Keyword arguments for rugplot(), by default {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04}
    fill_kws : Dict[str, Any], optional
        Keyword arguments to control the fill, by default {'color': 'brown', 'alpha': 0.1}
    font_kws : Dict[str, Any], optional
        Keyword arguments to control the font, by default {'color':  '#111111', 'weight': 'normal', 'size': 11}

    Returns
    -------
    ax: matplotlib Axes or None
        One figure is created per numerical column; the Axes of the last one is returned. None is
        returned (after printing a notice) when the data contains no numerical columns.
    """

    # ``import scipy`` alone does not guarantee that the ``stats`` submodule is loaded
    from scipy import stats

    # Coerce first so that ``data.shape`` is available to the validators for any 2D input
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(hist, "hist")
    _validate_input_int(bins, "bins")
    _validate_input_range(bins, "bins", 0, data.shape[0])
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies keep the caller's dictionaries untouched)
    kde_kws = {"alpha": 0.7, "linewidth": 1.5} if kde_kws is None else kde_kws.copy()
    rug_kws = (
        {"color": "brown", "alpha": 0.5, "linewidth": 2, "height": 0.04}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = {"color": "brown", "alpha": 0.1} if fill_kws is None else fill_kws.copy()
    font_kws = {"color": "#111111", "weight": "normal", "size": 11} if font_kws is None else font_kws.copy()

    data = data.dropna(axis=1, how="all")
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if len(cols) == 0:
        print("No columns with numeric data were detected.")
        return None

    # Strictly more than 20: with exactly 20 columns nothing is cut off, so no notice is due
    if len(cols) > 20 and not showall:
        print(
            f"Note: The number of numerical features is very large ({len(cols)}), please consider splitting the data. "
            "Showing plots for the first 20 numerical features. Override this by setting showall=True."
        )
        cols = cols[:20]

    for col in cols:
        col_data = data[col]
        dropped_values = col_data.isna().sum()
        if dropped_values > 0:
            col_data = col_data.dropna(axis=0)
            print(f"Dropped {dropped_values} missing values from column {col}.")

        _, ax = plt.subplots(figsize=figsize)
        # NOTE(review): distplot is flagged as deprecated by recent seaborn releases - confirm
        # against the seaborn version this package pins before upgrading.
        ax = sns.distplot(
            col_data,
            bins=bins,
            hist=hist,
            rug=True,
            kde_kws=kde_kws,
            rug_kws=rug_kws,
            hist_kws={"alpha": 0.5, "histtype": "step"},
        )

        # Vertical lines and fill
        # The first line on the Axes is taken to be the density curve
        x, y = ax.lines[0].get_xydata().T
        lower_q = np.quantile(col_data, fill_range[0])
        upper_q = np.quantile(col_data, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= lower_q) & (x <= upper_q)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_data)
        median = np.median(col_data)
        std = stats.tstd(col_data)
        # Each marker line ends on the density curve (interpolated height at that x)
        ax.vlines(
            x=mean, ymin=0, ymax=np.interp(mean, x, y), ls="dotted", color=mean_color, lw=2, label="mean"
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        # Widen the view by 15% of its span. Scaling both limits by a factor instead would
        # shift the window and clip the left side whenever the data lies away from zero.
        x_low, x_high = ax.get_xlim()
        x_pad = 0.075 * (x_high - x_low)
        ax.set_xlim(x_low - x_pad, x_high + x_pad)

        # Annotations and legend
        ax.text(0.01, 0.85, f"Mean: {np.round(mean,2)}", fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.7, f"Std. dev: {np.round(std,2)}", fontdict=font_kws, transform=ax.transAxes)
        ax.text(
            0.01,
            0.55,
            f"Skew: {np.round(stats.skew(col_data),2)}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.01,
            0.4,
            f"Kurtosis: {np.round(stats.kurtosis(col_data),2)}",  # Excess Kurtosis
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(0.01, 0.25, f"Count: {len(col_data)}", fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc="upper right")

    return ax
524
525
526
# Missing value plot
527
def missingval_plot(
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: Tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
):
    """ Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided, the
        index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. 'Greys', 'RdPu'. More information can be found in the
        matplotlib documentation, by default "PuBuGn"
    figsize : Tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns without any missing
        values, by default False
    spine_color : str, optional
        Set to 'None' to hide the spines on all plots or use any valid matplotlib color argument, by
        default "#EEEEEE"

    Returns
    -------
    GridSpec or None
        gs: Figure with array of Axes objects. None is returned (after printing a notice) when the
        dataset contains no missing values.
    """

    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Order columns by their number of missing values and keep only those that have any
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = (
            mv_cols_sorted.drop(mv_cols_sorted[mv_cols_sorted.values == 0].keys().tolist()).keys().tolist()
        )
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    # NOTE(review): relies on the order of the values returned by _missing_vals - confirm against klib.utils
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        ax1.text(
            0.1 + rect.get_x() + rect.get_width() / 2,
            height + 0.5,
            label,
            ha="center",
            va="bottom",
            rotation=90,  # numeric: current matplotlib rejects the string "90"
            alpha=0.5,
            fontsize=11,
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(), horizontalalignment="center", fontweight="light", fontsize=12
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    ax3.text(
        0.025,
        0.875,
        f"Total: {np.round(total_datapoints/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025, 0.675, f"Missing: {np.round(mv_total/1000,1)}K", transform=ax3.transAxes, fontdict=fontax3
    )
    ax3.text(
        0.025,
        0.475,
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.275,
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.075,
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(mv_rows, range(len(mv_rows)), s=mv_rows, c=mv_rows, cmap=cmap, marker=".", vmin=1)
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle("Missing value plot", x=0.45, y=0.94, fontsize=18, color="#111111")

    return gs
680