GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Test Failed
Push — main ( 698010...a51334 )
by Andreas
01:56
created

klib.describe.cat_plot()   D

Complexity

Conditions 10

Size

Total Lines 141
Code Lines 86

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 86
nop 6
dl 0
loc 141
rs 4.6581
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex units of code like klib.describe.cat_plot() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
Functions for descriptive analytics.
3
4
:author: Andreas Kanz
5
6
"""
7
from __future__ import annotations
8
9
from typing import Any
10
from typing import Literal
11
from typing import Optional
12
13
import matplotlib.pyplot as plt
14
import numpy as np
15
import pandas as pd
16
import scipy
17
import seaborn as sns
18
from matplotlib import ticker
19
from matplotlib.colors import LinearSegmentedColormap
20
from matplotlib.colors import to_rgb
21
22
from klib.utils import _corr_selector
23
from klib.utils import _missing_vals
24
from klib.utils import _validate_input_bool
25
from klib.utils import _validate_input_int
26
from klib.utils import _validate_input_num_data
27
from klib.utils import _validate_input_range
28
from klib.utils import _validate_input_smaller
29
from klib.utils import _validate_input_sum_larger
30
31
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["cat_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
32
33
34
def cat_plot(
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
):
    """Two-dimensional visualization of the number and frequency of categorical \
        features.

    For every non-numeric column a bar chart of the most and least frequent \
    values plus a short text summary is drawn. Below, a heatmap over all rows \
    shows where the "top" and "bottom" values occur in the data.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by \
        default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by \
        default "#d8b365"

    Returns
    -------
    Gridspec
        gs: Figure with array of Axes objects. None is returned (after printing a \
        message) if the data contains no non-numeric columns.
    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")

    # Coerce before the range checks: they read data.shape[1], which only exists
    # reliably once the input is a DataFrame. The copy protects the caller's data
    # from the in-place encoding done further below.
    data = pd.DataFrame(data).copy()

    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if not cols:
        print("No columns with categorical data were detected.")
        return None

    # Work on plain object columns so that the numeric codes assigned below can be
    # written into the column.
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Fewer distinct values than requested bars: split what is available
        # between both groups, favouring the group the caller asked more of.
        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        # Positional slices; "-0" would select everything, hence the guard.
        value_counts_top = value_counts.iloc[:lim_top]
        value_counts_bot = (
            value_counts.iloc[-lim_bot:] if lim_bot > 0 else value_counts.iloc[:0]
        )
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Encode the column for the heatmap: 10 = top value, 0 = bottom value,
        # 5 = everything else; the rolling mean smooths neighbouring rows.
        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Bar labels, truncated to 20 characters. str() is required because
        # non-numeric columns may hold values that cannot be sliced (bool,
        # datetime, ...).
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85
        )
        ax_top.bar(
            value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap", [color_bot_rgb, color_white, color_top_rgb], N=200
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    # Keep every fifth tick, rounded to the nearest ten
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot", x=0.5, y=0.91, fontsize=18, color="#111111"
    )

    return gs
175
176
177
def corr_mat(
    data: pd.DataFrame,
    split: Optional[Literal["pos", "neg", "high", "low"]] = None,
    threshold: float = 0,
    target: Optional[pd.DataFrame | pd.Series | np.ndarray | str] = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
        split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | np.ndarray | list | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None. A string \
        selects (and removes) that column from data. A list or array is matched to \
        the rows of data by position, a Series by its index.
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed \
            and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. \
            Computationally more expensive but more robust in smaller datasets than \
            "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in red, by \
        default True

    Returns
    -------
    pd.DataFrame | pandas Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        # An empty string means "no extra style"; avoids emitting "color: None"
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    # NOTE(review): pd.DataFrame appears in the annotation of target but is not
    # matched here, so such a target falls through to the full matrix below.
    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)
            # corrwith aligns on the index. Lists/arrays carry no index of their
            # own, so pair them with the rows of data by position; otherwise a
            # non-default index on data would yield only NaN.
            if not isinstance(target, pd.Series) and len(target_data) == len(data):
                target_data.index = data.index
            target = target_data.name

        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap is deprecated in newer pandas in favour of Styler.map;
        # fall back to applymap where map does not exist yet.
        style_cells = getattr(styler, "map", None) or styler.applymap
        return style_cells(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
253
254
255
def corr_plot(
    data: pd.DataFrame,
    split: Optional[Literal["pos", "neg", "high", "low"]] = None,
    threshold: float = 0,
    target: Optional[pd.Series | str] = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,
):
    """Two-dimensional visualization of the correlation between feature-columns \
        excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by default \
        None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns \
                above the threshold
            * neg: visualize all negative correlations between the feature-columns \
                below the threshold
            * high: visualize all correlations between the feature-columns for \
                which abs (corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which \
                abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
            split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally \
                distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships. \
                Computationally more expensive but more robust in smaller datasets \
                than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or \
        object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10). Not used when an \
        existing Axes is supplied via the "ax" keyword argument.
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the \
        settings are not displayed, by default False

    Keyword Arguments : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the upper \
                triangle is shown. Set dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation \
                coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the cbar.
            * vmin: float, default is calculated from the given correlation \
                coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the cbar.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when \
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * center: float, default 0
                Value at which the colormap is centered.
            * fmt: str, default ".2f"
                Format string used for the annotations.
            * ax: matplotlib Axes, default is a newly created Axes
                Existing Axes to draw the heatmap into.
            * Many more kwargs are available, i.e. "alpha" to control blending, or \
                options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    # With a target there is a single column of correlations and nothing to hide;
    # for the full (symmetric) matrix the upper triangle is masked.
    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Default colorbar limits, derived from the visible coefficients
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    # Draw into a caller-supplied Axes if there is one, otherwise create a figure
    ax = kwargs.get("ax")
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Specify kwargs for the heatmap. Everything before **kwargs is a default the
    # caller may override; "ax" is set last so that the heatmap, the title and
    # the return value always refer to the same Axes.
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        "center": 0,
        "fmt": ".2f",
        **kwargs,
        "ax": ax,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
410
411
412
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: int = 3,
    fill_range: tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: Optional[dict[str, Any]] = None,
    rug_kws: Optional[dict[str, Any]] = None,
    fill_kws: Optional[dict[str, Any]] = None,
    font_kws: Optional[dict[str, Any]] = None,
):
    """Two-dimensional visualization of the distribution of non binary numerical \
        features.

    One figure is created per numeric column with more than two distinct values. \
    Each shows a KDE curve with rug, a shaded quantile range, markers for mean, \
    median and mean +/- one standard deviation, and a few summary statistics.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default "orange"
    size : float, optional
        Controls the plot size, by default 3
    fill_range : tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is about \
        two std. deviations above and below the mean, by default (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {"alpha": 0.75, \
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : dict[str, Any], optional
        Keyword arguments for rugplot(), by default {"color": "#ff3333", \
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff", \
        "alpha": 0.2}
    font_kws : dict[str, Any], optional
        Keyword arguments to control the font, by default {"color":  "#111111", \
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking. None is \
        returned (after printing a message) if no suitable numeric column exists.
    """
    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies keep the caller's dictionaries untouched)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
        if kde_kws is None
        else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = (
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    )
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
    # df keeps all rows for the summary statistics; data may be subsampled below
    # and is only used for drawing.
    df = data.copy()
    data = data.loc[:, data.nunique() > 2]
    if data.shape[0] > 10000:
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset."
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    if len(cols) >= 20 and not showall:
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True."
        )
        cols = cols[:20]

    ax = None
    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill
        # x/y coordinates of the first line drawn on the Axes (the KDE curve)
        x, y = ax.lines[0].get_xydata().T
        ax.fill_between(
            x,
            y,
            where=(
                (x >= np.quantile(col_df, fill_range[0]))
                & (x <= np.quantile(col_df, fill_range[1]))
            ),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = scipy.stats.tstd(col_df)
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        # Widen the x-range a little (more on the left, where the statistics are
        # printed). The padding is based on the magnitude of the upper limit so
        # that it also widens - rather than shrinks - the range when that limit
        # is negative; for positive limits this equals the previous formula.
        x_lower, x_upper = ax.get_xlim()
        ax.set_xlim(
            x_lower - abs(x_upper) * 0.05,
            x_upper + abs(x_upper) * 0.03,
        )

        # Annotations and legend
        ax.text(
            0.005,
            0.9,
            f"Mean: {mean:.2f}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.005,
            0.7,
            f"Std. dev: {std:.2f}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.005,
            0.5,
            f"Skew: {scipy.stats.skew(col_df):.2f}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.005,
            0.3,
            f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}",  # Excess Kurtosis
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.text(
            0.005,
            0.1,
            f"Count: {len(col_df)}",
            fontdict=font_kws,
            transform=ax.transAxes,
        )
        ax.legend(loc="upper right")

    return ax
610
611
612
def missingval_plot(
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
):
    """Two-dimensional visualization of the missing values in a dataset.

    The figure combines four panels: a bar chart of missing values per column, a \
    heatmap marking every missing cell, a text summary and a scatter plot of \
    missing values per row.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information can be \
        found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns \
        without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid matplotlib \
        color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects. None is returned (after printing a \
        message) if the dataset contains no missing values.
    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Columns ordered by number of missing values, without the complete ones
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted[mv_cols_sorted.values != 0].keys().tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            # numeric on purpose: matplotlib expects a number or
            # 'vertical'/'horizontal' here, not a numeric string
            rotation=90,
            alpha=0.5,
            fontsize=11,
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # Keep every fifth tick, rounded to the nearest ten
    ax2.set_yticks(np.round(ax2.get_yticks()[::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize=12,
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    ax3.text(
        0.025,
        0.875,
        f"Total: {np.round(total_datapoints/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.675,
        f"Missing: {np.round(mv_total/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.475,
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.275,
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.075,
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot", x=0.45, y=0.94, fontsize=18, color="#111111"
    )

    return gs
787