GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Passed
Push — main ( 126b49...6e0315 )
by Andreas
01:46 queued 11s
created

klib.describe.cat_plot()   C

Complexity

Conditions 9

Size

Total Lines 142
Code Lines 85

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 85
nop 6
dl 0
loc 142
rs 5.1575
c 0
b 0
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""
2
Functions for descriptive analytics.
3
4
:author: Andreas Kanz
5
6
"""
7
8
# Imports
9
import matplotlib.pyplot as plt
10
import matplotlib.ticker as ticker
11
import numpy as np
12
import pandas as pd
13
import scipy
14
import seaborn as sns
15
from matplotlib.colors import LinearSegmentedColormap, to_rgb
16
from typing import Any, Dict, Optional, Tuple, Union
17
18
from klib.utils import (
19
    _corr_selector,
20
    _missing_vals,
21
    _validate_input_bool,
22
    _validate_input_int,
23
    _validate_input_range,
24
    _validate_input_smaller,
25
    _validate_input_sum_larger,
26
)
27
28
__all__ = ["cat_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
29
30
31
# Functions
32
33
# Categorical Plot
34
def cat_plot(
    data: pd.DataFrame,
    figsize: Tuple = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
):
    """Two-dimensional visualization of the number and frequency of categorical features.

    For every non-numeric column a bar chart of its most and least frequent
    values and a short text summary are drawn; below them a single heatmap shows
    where those values occur in the dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    figsize : Tuple, optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values,
        by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common
        values, by default "#d8b365"

    Returns
    -------
    Gridspec
        gs: Figure with array of Axes objects. None is returned (after printing
        a notice) when the data holds no non-numeric columns.

    Notes
    -----
    Argument checking is delegated to the ``_validate_input_*`` helpers; how
    they report invalid input is not visible from this function.
    """

    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if len(cols) == 0:
        print("No columns with categorical data were detected.")
        return None

    # Plain object columns keep value_counts() and nunique() in agreement
    # (a categorical would also list categories that never occur).
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        column = data[col]
        n_unique = column.nunique(dropna=True)
        value_counts = column.value_counts()
        lim_top, lim_bot = _cat_plot_limits(n_unique, top, bottom)

        # Positional slices; an explicit start avoids the "[-0:]" pitfall,
        # which would select every value when lim_bot is 0.
        value_counts_top = value_counts.iloc[:lim_top]
        value_counts_bot = value_counts.iloc[len(value_counts) - lim_bot :]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_idx_bot = value_counts_bot.index.tolist()

        # Encode each cell for the heatmap: 10 = top value, 0 = bottom value,
        # 5 = anything else (including missing). The codes are derived from the
        # untouched column, so categories equal to 0 or 10 cannot be confused
        # with the codes themselves. Top takes precedence over bottom.
        codes = np.select(
            [
                column.isin(value_counts_idx_top).to_numpy(),
                column.isin(value_counts_idx_bot).to_numpy(),
            ],
            [10.0, 0.0],
            default=5.0,
        )
        # Smooth neighbouring rows so that single cells remain visible.
        smoothed = pd.Series(codes, index=data.index).rolling(2, min_periods=1).mean()
        data[col] = smoothed.to_numpy()

        # Labels are shortened to 20 characters; str() keeps this working for
        # non-string categories such as booleans or timestamps.
        labels_top = [str(elem)[:20] for elem in value_counts_idx_top]
        labels_bot = [str(elem)[:20] for elem in value_counts_idx_bot]
        sum_top = int(value_counts_top.sum())
        sum_bot = int(value_counts_bot.sum())

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            labels_top, value_counts_top.to_numpy(), color=bar_color_top, width=0.85
        )
        ax_top.bar(
            labels_bot,
            value_counts_bot.to_numpy(),
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top} vals: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot} vals: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap", [color_bot_rgb, color_white, color_top_rgb], N=200
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    # Keep every fifth tick, rounded to the nearest ten, to avoid clutter.
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot", x=0.5, y=0.91, fontsize=18, color="#111111"
    )

    return gs


def _cat_plot_limits(n_unique: int, top: int, bottom: int) -> Tuple[int, int]:
    """Return how many top and bottom values to show without overlapping.

    When fewer distinct values exist than requested, they are split roughly in
    half (an odd leftover goes to the bottom group) while never exceeding the
    requested ``top`` / ``bottom`` counts and never using a value twice.
    """
    if n_unique >= top + bottom:
        return top, bottom

    lim_top = min(n_unique // 2, top)
    lim_bot = min(n_unique - lim_top, bottom)
    # Hand values the bottom group could not take back to the top group.
    lim_top = min(n_unique - lim_bot, top)
    return int(lim_top), int(lim_bot)
176
177
178
# Correlation Matrix
179
def corr_mat(
    data: pd.DataFrame,
    split: Optional[
        str
    ] = None,  # Optional[Literal['pos', 'neg', 'high', 'low']] = None,
    threshold: float = 0,
    target: Optional[Union[pd.DataFrame, pd.Series, np.ndarray, str]] = None,
    method: str = "pearson",  # Literal['pearson', 'spearman', 'kendall'] = "pearson",
    colored: bool = True,
) -> Union[pd.DataFrame, Any]:
    """Returns a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[str], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}. The value is forwarded unchanged
        to ``_corr_selector``.
    threshold : float, optional
        Correlation threshold, by default 0. It is checked by
        ``_validate_input_range`` with the bounds -1 and 1 and then forwarded
        to ``_corr_selector`` (which presumably applies its own default for
        split = "high" / "low" - confirm in klib.utils).
    target : Optional[Union[pd.Series, np.ndarray, list, str]], optional
        Specify target for correlation, by default None.
        * str: name of a column in ``data``; the column is removed from the
          features and every remaining column is correlated with it.
        * list / pd.Series / np.ndarray: converted with ``pd.Series`` and
          correlated with every column of ``data``.
        * anything else (including None): the full pairwise correlation
          matrix of ``data`` is computed.
    method : str, optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally
          distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic
          relationships.
        * kendall: ranked/ordinal correlation, measures monotonic
          relationships. Computationally more expensive but more robust in
          smaller datasets than "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in
        red, by default True

    Returns
    -------
    Union[pd.DataFrame, pd.Styler]
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val):
        # An empty string leaves the cell unstyled; this avoids emitting the
        # invalid CSS declaration "color: None" for non-negative values.
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)
            # The Series name (possibly None) becomes the result's column label.
            target = target_data.name

        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap was renamed to Styler.map in newer pandas releases;
        # prefer the new name and fall back to the old one when it is missing.
        style_cells = getattr(styler, "map", None) or styler.applymap
        return style_cells(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
256
257
258
# Correlation matrix / heatmap
259
def corr_plot(
    data: pd.DataFrame,
    split: Optional[str] = None,
    threshold: float = 0,
    target: Optional[Union[pd.Series, str]] = None,
    method: str = "pearson",
    cmap: str = "BrBG",
    figsize: Tuple = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,
):
    """Two-dimensional visualization of the correlation between feature-columns
    excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by
        default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the
              feature-columns above the threshold
            * neg: visualize all negative correlations between the
              feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for
              which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for
              which abs(corr) < threshold is True

    threshold : float, optional
        Correlation threshold, by default 0. It is checked by
        ``_validate_input_range`` with the bounds -1 and 1 and forwarded to
        ``corr_mat``.
    target : Optional[Union[pd.Series, str]], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None
    method : str, optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally
              distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic
              relationships.
            * kendall: ranked/ordinal correlation, measures monotonic
              relationships. Computationally more expensive but more robust in
              smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name
        or object, or list of colors, by default "BrBG"
    figsize : Tuple, optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False,
        the settings are not displayed, by default False

    Keyword Arguments : optional
        Additional elements to control the visualization of the plot. They are
        merged over the defaults below and passed to ``sns.heatmap``, e.g.:

            * mask: by default the upper triangle is hidden when no target is
              given. Set to False to show the entire correlation matrix; set
              dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation
              coefficients (largest visible value minus 0.05).
            * vmin: float, default is calculated from the given correlation
              coefficients (smallest visible value plus 0.05).
            * linewidths: float, default 0.5
              Controls the line-width in between the squares.
            * annot_kws: dict, default {"size" : 10}
              Controls the font size of the annotations. Only available when
              annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
              Controls the size of the colorbar.
            * Many more kwargs are available, i.e. "alpha" to control blending,
              or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see
        above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    # The builtin bool is used as dtype: the np.bool alias was removed from
    # NumPy and raises AttributeError on current releases.
    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        # Without a target the matrix is symmetric, so hide the upper triangle.
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Color limits are derived from the visible (unmasked) coefficients only.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap; caller-supplied kwargs win over defaults
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, center=0, fmt=".2f", **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings (vmax / vmin are read from kwargs so that overrides are reported)
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {kwargs['vmax']} \n\
                - vmin: {kwargs['vmin']} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
415
416
417
# Distribution plot
418
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: float = 2.5,
    fill_range: Tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: Dict[str, Any] = None,
    rug_kws: Dict[str, Any] = None,
    fill_kws: Dict[str, Any] = None,
    font_kws: Dict[str, Any] = None,
):
    """Two-dimensional visualization of the distribution of non binary numerical features.

    One KDE plot with rug is created per numeric column that has more than two
    distinct values. Each plot is annotated with mean, median, standard
    deviation, skew, kurtosis and count.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default
        "orange"
    size : float, optional
        Controls the plot size (passed as ``height`` to ``sns.displot``), by
        default 2.5
    fill_range : Tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is
        about two std. deviations above and below the mean, by default
        (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : Dict[str, Any], optional
        Keyword arguments for the KDE plot, by default {"alpha": 0.75,
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : Dict[str, Any], optional
        Keyword arguments for the rug plot, by default {"color": "#ff3333",
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : Dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff",
        "alpha": 0.2}
    font_kws : Dict[str, Any], optional
        Keyword arguments to control the font, by default {"color": "#111111",
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking, or None
        (after printing a notice) if no suitable numeric column was found.
    """
    # Imported explicitly: "import scipy" alone does not guarantee that the
    # scipy.stats submodule has been loaded.
    from scipy import stats

    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies protect the caller's dictionaries)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
        if kde_kws is None
        else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = (
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    )
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
    # Full copy used for the summary statistics, even if the plots are sampled.
    df = data.copy()
    data = data.loc[:, data.nunique() > 2]
    if data.shape[0] > 10000:
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset."
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if len(cols) == 0:
        print("No columns with numeric data were detected.")
        return None

    if len(cols) >= 20 and showall is False:
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True."
        )
        cols = cols[:20]

    ax = None
    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill
        # x / y are the coordinates of the first line drawn on the axes,
        # which is taken to be the KDE curve.
        x, y = ax.lines[0].get_xydata().T
        lower = np.quantile(col_df, fill_range[0])
        upper = np.quantile(col_df, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= lower) & (x <= upper)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = stats.tstd(col_df)
        # Each marker line ends where it meets the interpolated KDE curve.
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        # Widen the x-range a little so the annotations do not cover the curve.
        x_low, x_high = ax.get_xlim()
        ax.set_xlim(x_low - x_high * 0.05, x_high * 1.03)

        # Annotations and legend
        annotations = (
            (0.9, f"Mean: {mean:.2f}"),
            (0.7, f"Std. dev: {std:.2f}"),
            (0.5, f"Skew: {stats.skew(col_df):.2f}"),
            (0.3, f"Kurtosis: {stats.kurtosis(col_df):.2f}"),  # Excess Kurtosis
            (0.1, f"Count: {len(col_df)}"),
        )
        for y_pos, text in annotations:
            ax.text(0.005, y_pos, text, fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc="upper right")

    return ax
616
617
618
# Missing value plot
619
def missingval_plot(
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: Tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
):
    """Two-dimensional visualization of the missing values in a dataset.

    The figure combines four panels: a bar chart of missing values per column,
    a heatmap marking the missing cells, a text summary and a scatter plot of
    missing values per row.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information
        can be found in the matplotlib documentation, by default "PuBuGn"
    figsize : Tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop
        columns without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid
        matplotlib color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects. None is returned (after printing
        a notice) when the dataset holds no missing values.
    """

    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = (
            mv_cols_sorted.drop(
                mv_cols_sorted[mv_cols_sorted.values == 0].keys().tolist()
            )
            .keys()
            .tolist()
        )
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    # NOTE(review): relies on the value order of the mapping returned by
    # _missing_vals (total, per row, per column, <unused>, per-column ratio).
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        ax1.text(
            0.1 + rect.get_x() + rect.get_width() / 2,
            height + 0.5,
            label,
            ha="center",
            va="bottom",
            # Numeric rotation: a numeric string such as "90" is rejected by
            # current Matplotlib releases.
            rotation=90,
            alpha=0.5,
            fontsize=11,
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # Keep every fifth tick, rounded to the nearest ten, to avoid clutter.
    ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize=12,
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    summary_lines = (
        (0.875, f"Total: {np.round(total_datapoints/1000,1)}K"),
        (0.675, f"Missing: {np.round(mv_total/1000,1)}K"),
        (0.475, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%"),
        (0.275, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%"),
        (0.075, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%"),
    )
    for y_pos, text in summary_lines:
        ax3.text(0.025, y_pos, text, transform=ax3.transAxes, fontdict=fontax3)

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot", x=0.45, y=0.94, fontsize=18, color="#111111"
    )

    return gs
795