GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Passed
Push — master ( bf399e...620af2 )
by Andreas
01:17
created

klib.describe.corr_mat()   B

Complexity

Conditions 6

Size

Total Lines 72
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 29
nop 6
dl 0
loc 72
rs 8.2506
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""
2
Functions for descriptive analytics.
3
4
:author: Andreas Kanz
5
6
"""
7
8
# Imports
9
from matplotlib.colors import LinearSegmentedColormap, to_rgb
10
import matplotlib.pyplot as plt
11
import matplotlib.ticker as ticker
12
import numpy as np
13
import pandas as pd
14
import scipy
15
import seaborn as sns
16
17
from typing import Any, Dict, Optional, Tuple, Union
18
from klib.utils import (
19
    _corr_selector,
20
    _missing_vals,
21
    _validate_input_bool,
22
    _validate_input_int,
23
    _validate_input_range,
24
    _validate_input_smaller,
25
    _validate_input_sum_larger,
26
)
27
28
29
__all__ = ["cat_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
30
31
32
# Functions
33
34
# Categorical Plot
35
def cat_plot(
    data: pd.DataFrame,
    figsize: Tuple = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
):
    """Two-dimensional visualization of the number and frequency of categorical features.

    For every non-numeric column a bar chart of the most and least frequent values and a short
    text summary are drawn; a heatmap below marks where those values occur in the data.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided,
        the index/column information is used to label the plots
    figsize : Tuple, optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by default "#d8b365"

    Returns
    -------
    Gridspec
        gs: Figure with array of Axes objects. None is returned (after printing a notice) when
        the data contains no non-numeric columns.
    """

    # Coerce first so the shape-based validation below also works for inputs that are not yet a
    # DataFrame; the copy keeps the caller's data untouched by the in-place recoding further down.
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    if not cols:
        # Nothing to plot; a grid with zero columns cannot be created.
        print("No columns with categorical data were detected.")
        return None

    data = data[cols]
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            # Plain object columns accept the numeric marker values assigned below.
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):

        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Shrink the limits when the column has fewer distinct values than requested.
        if n_unique < top + bottom:
            lim_top = int(n_unique // 2)
            lim_bot = int(n_unique // 2) + 1

        if n_unique <= 2:
            lim_top = lim_bot = int(n_unique // 2)

        value_counts_top = value_counts[0:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        # NOTE(review): with lim_bot == 0 the slice [-0:] selects every value - confirm intent.
        value_counts_bot = value_counts[-lim_bot:]
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Recode the column for the heatmap: 3 = top value, -3 = bottom value, 0 = anything else.
        data.loc[data[col].isin(value_counts_idx_top), col] = 3
        data.loc[data[col].isin(value_counts_idx_bot), col] = -3
        data.loc[((data[col] != 3) & (data[col] != -3)), col] = 0
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Truncate long labels; str() keeps this working for non-string category values.
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85)
        ax_top.bar(value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85)
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top} vals: {sum(value_counts_top)} ({sum(value_counts_top)/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot} vals: {sum(value_counts_bot)} ({sum(value_counts_bot)/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    top_rgb = to_rgb(bar_color_top)
    color_white = to_rgb("#FFFFFF")
    bot_rgb = to_rgb(bar_color_bottom)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap", [bot_rgb, bot_rgb, color_white, top_rgb, top_rgb], N=100
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=-4.25, vmax=4.25, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(), horizontalalignment="center", fontweight="light", fontsize="medium"
    )
    ax_hm.tick_params(length=1, colors="#111111")

    gs.figure.suptitle("Categorical data plot", x=0.5, y=0.91, fontsize=18, color="#111111")

    return gs
166
167
168
# Correlation Matrix
169
def corr_mat(
    data: pd.DataFrame,
    split: Optional[str] = None,  # Optional[Literal['pos', 'neg', 'high', 'low']] = None,
    threshold: float = 0,
    target: Optional[Union[pd.DataFrame, pd.Series, np.ndarray, str]] = None,
    method: str = "pearson",  # Literal['pearson', 'spearman', 'kendall'] = "pearson",
    colored: bool = True,
) -> Union[pd.DataFrame, Any]:
    """Returns a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided,
        the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed, by default None
        {None, 'pos', 'neg', 'high', 'low'}
    threshold : float, optional
        Threshold passed on to the split selection, by default 0. The input validation checks
        it against the range -1 to 1.
    target : Optional[Union[pd.Series, np.ndarray, list, str]], optional
        Specify target for correlation. E.g. label column to generate only the correlations
        between each feature and the label, by default None. A string is treated as the name of
        a column in ``data``; that column is removed from the features before correlating.
    method : str, optional
        method: {'pearson', 'spearman', 'kendall'}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed and
            homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally
            more expensive but more robust in smaller datasets than 'spearman'
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in red, by default True

    Returns
    -------
    Union[pd.DataFrame, pandas Styler]
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val):
        # An empty string means "no styling"; emitting "color: None" would be invalid CSS.
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)
            if not isinstance(target, pd.Series) and len(target_data) == len(data):
                # Lists/arrays carry no index of their own; reuse the data's index so that
                # corrwith pairs the rows positionally instead of by a fresh RangeIndex.
                target_data.index = data.index
            target = target_data.name

        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if not colored:
        return corr

    styler = corr.style
    # Styler.applymap was renamed to Styler.map in pandas 2.1; use whichever is available.
    apply_elementwise = styler.map if hasattr(styler, "map") else styler.applymap
    return apply_elementwise(color_negative_red).format("{:.2f}", na_rep="-")
241
242
243
# Correlation matrix / heatmap
244
def corr_plot(
    data: pd.DataFrame,
    split: Optional[str] = None,
    threshold: float = 0,
    target: Optional[Union[pd.Series, str]] = None,
    method: str = "pearson",
    cmap: str = "BrBG",
    figsize: Tuple = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,
):
    """Two-dimensional visualization of the correlation between feature-columns, excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided,
        the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, 'pos', 'neg', 'high', 'low'}, by default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns above the threshold
            * neg: visualize all negative correlations between the feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Threshold used by the split, by default 0. The input validation checks it against the
        range -1 to 1.
    target : Optional[Union[pd.Series, str]], optional
        Specify target for correlation. E.g. label column to generate only the correlations
        between each feature and the label, by default None
    method : str, optional
        method: {'pearson', 'spearman', 'kendall'}, by default "pearson"
            * pearson: measures linear relationships and requires normally distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships. Computationally
              more expensive but more robust in smaller datasets than 'spearman'.

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or object, or list
        of colors, by default "BrBG"
    figsize : Tuple, optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the settings are not
        displayed, by default False

    Keyword Arguments : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the upper triangle is
                shown. Set dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the colorbar.
            * vmin: float, default is calculated from the given correlation coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the colorbar.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {'size' : 10}
                Controls the font size of the annotations. Only available when annot = True.
            * cbar_kws: dict, default {'shrink': .95, 'aspect': 30}
                Controls the size of the colorbar.
            * center: float, default 0, and fmt: str, default ".2f"
                Forwarded to the heatmap; can be overridden like any other keyword argument.
            * Many more kwargs are available, i.e. 'alpha' to control blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(data, split=split, threshold=threshold, target=target, method=method, colored=False)

    # The builtin bool is used as dtype; the np.bool alias no longer exists in current NumPy.
    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        # Full feature-by-feature matrix: hide the upper triangle including the diagonal.
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Colorbar limits are derived from the visible cells only, pulled in by 0.05.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap; user-supplied kwargs come last and therefore win.
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        "center": 0,
        "fmt": ".2f",
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
380
381
382
# Distribution plot
383
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    figsize: Tuple = (14, 2),
    fill_range: Tuple = (0.025, 0.975),
    hist: bool = False,
    bins: int = 10,
    showall: bool = False,
    kde_kws: Dict[str, Any] = None,
    rug_kws: Dict[str, Any] = None,
    fill_kws: Dict[str, Any] = None,
    font_kws: Dict[str, Any] = None,
):
    """Two-dimensional visualization of the distribution of numerical features.

    One figure is created per numeric column, showing the density estimate, a rug, a shaded
    quantile range and markers for mean, median and mean +/- one standard deviation.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided,
        the index/column information is used to label the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default "orange"
    figsize : Tuple, optional
        Controls the figure size, by default (14, 2)
    fill_range : Tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is about two std.
        deviations above and below the mean, by default (0.025, 0.975)
    hist : bool, optional
        Set to True to display histogram bars in the plot, by default False
    bins : int, optional
        Specification of the number of hist bins. Requires hist = True, by default 10
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : Dict[str, Any], optional
        Keyword arguments for the density estimate, by default {'alpha': 0.7, 'linewidth': 1.5}
    rug_kws : Dict[str, Any], optional
        Keyword arguments for the rug, by default {'color': 'brown', 'alpha': 0.5, 'linewidth': 2, 'height': 0.04}
    fill_kws : Dict[str, Any], optional
        Keyword arguments to control the fill, by default {'color': 'brown', 'alpha': 0.1}
    font_kws : Dict[str, Any], optional
        Keyword arguments to control the font, by default {'color':  '#111111', 'weight': 'normal', 'size': 11}

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking. None is returned (after
        printing a notice) when the data contains no numeric columns.
    """

    # Coerce first so the row-count based validation of ``bins`` also works for inputs that are
    # not already a DataFrame; the copy keeps the caller's data untouched.
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(hist, "hist")
    _validate_input_int(bins, "bins")
    _validate_input_range(bins, "bins", 0, data.shape[0])
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies avoid mutating dictionaries owned by the caller)
    kde_kws = {"alpha": 0.7, "linewidth": 1.5} if kde_kws is None else kde_kws.copy()
    rug_kws = (
        {"color": "brown", "alpha": 0.5, "linewidth": 2, "height": 0.04}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = {"color": "brown", "alpha": 0.1} if fill_kws is None else fill_kws.copy()
    font_kws = {"color": "#111111", "weight": "normal", "size": 11} if font_kws is None else font_kws.copy()

    data = data.dropna(axis=1, how="all")
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    # Only announce the limit when columns are actually left out.
    if len(cols) > 20 and showall is False:
        print(
            f"Note: The number of numerical features is very large ({len(cols)}), please consider splitting the data. "
            "Showing plots for the first 20 numerical features. Override this by setting showall=True."
        )
        cols = cols[:20]

    for col in cols:
        dropped_values = data[col].isna().sum()
        if dropped_values > 0:
            col_data = data[col].dropna(axis=0)
            print(f"Dropped {dropped_values} missing values from column {col}.")

        else:
            col_data = data[col]

        _, ax = plt.subplots(figsize=figsize)
        ax = sns.distplot(
            col_data,
            bins=bins,
            hist=hist,
            rug=True,
            kde_kws=kde_kws,
            rug_kws=rug_kws,
            hist_kws={"alpha": 0.5, "histtype": "step"},
        )

        # Vertical lines and fill; the first line on the axes is taken as the density curve.
        x, y = ax.lines[0].get_xydata().T
        lower_bound = np.quantile(col_data, fill_range[0])
        upper_bound = np.quantile(col_data, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= lower_bound) & (x <= upper_bound)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_data)
        median = np.median(col_data)
        std = scipy.stats.tstd(col_data)
        # Each marker reaches up to the interpolated height of the curve at its position.
        ax.vlines(
            x=mean, ymin=0, ymax=np.interp(mean, x, y), ls="dotted", color=mean_color, lw=2, label="mean"
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        ax.set_xlim(ax.get_xlim()[0] * 1.15, ax.get_xlim()[1] * 1.15)

        # Annotations and legend
        ax.text(0.01, 0.85, f"Mean: {np.round(mean,2)}", fontdict=font_kws, transform=ax.transAxes)
        ax.text(0.01, 0.7, f"Std. dev: {np.round(std,2)}", fontdict=font_kws, transform=ax.transAxes)
        ax.text(
            0.01,
            0.55,
            f"Skew: {np.round(scipy.stats.skew(col_data),2)}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.01,
            0.4,
            f"Kurtosis: {np.round(scipy.stats.kurtosis(col_data),2)}",  # Excess Kurtosis
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(0.01, 0.25, f"Count: {len(col_data)}", fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc="upper right")

    return ax
541
542
543
# Missing value plot
544
def missingval_plot(
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: Tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
):
    """Two-dimensional visualization of the missing values in a dataset.

    The figure combines a bar chart of missing values per column, a heatmap of the missing
    cells, a text summary and a scatter plot of missing values per row.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame is provided,
        the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. 'Greys', 'RdPu'. More information can be found in the
        matplotlib documentation, by default "PuBuGn"
    figsize : Tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns without any
        missing values, by default False
    spine_color : str, optional
        Set to 'None' to hide the spines on all plots or use any valid matplotlib color argument,
        by default "#EEEEEE"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects. None is returned (after printing a notice) when
        the dataset contains no missing values.
    """

    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        complete_cols = mv_cols_sorted[mv_cols_sorted.values == 0].keys().tolist()
        final_cols = mv_cols_sorted.drop(complete_cols).keys().tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        ax1.text(
            0.1 + rect.get_x() + rect.get_width() / 2,
            height + 0.5,
            label,
            ha="center",
            va="bottom",
            rotation=90,  # numeric: a string such as "90" is not a valid rotation value
            alpha=0.5,
            fontsize=11,
        )

    ax1.set_frame_on(True)
    for _, spine in ax1.spines.items():
        spine.set_visible(True)
        spine.set_color(spine_color)
    # NOTE(review): presumably meant to blank out the top spine - confirm what None resolves to.
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(), horizontalalignment="center", fontweight="light", fontsize=12
    )
    ax2.tick_params(length=1, colors="#111111")
    for _, spine in ax2.spines.items():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    ax3.text(
        0.025,
        0.875,
        f"Total: {np.round(total_datapoints/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025, 0.675, f"Missing: {np.round(mv_total/1000,1)}K", transform=ax3.transAxes, fontdict=fontax3
    )
    ax3.text(
        0.025,
        0.475,
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.275,
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.075,
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for _, spine in ax4.spines.items():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(mv_rows, range(len(mv_rows)), s=mv_rows, c=mv_rows, cmap=cmap, marker=".", vmin=1)
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle("Missing value plot", x=0.45, y=0.94, fontsize=18, color="#111111")

    return gs
697