GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.
Test Failed
Push — main ( 496729...3d130d )
by Andreas
04:22
created

klib.describe.corr_interactive_plot()   C

Complexity

Conditions 7

Size

Total Lines 211
Code Lines 66

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 66
nop 9
dl 0
loc 211
rs 6.7127
c 0
b 0
f 0

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
"""Functions for descriptive analytics.
2
3
:author: Andreas Kanz
4
5
"""
6
from __future__ import annotations
7
8
from typing import Any
9
from typing import Literal
10
11
import matplotlib.pyplot as plt
12
import numpy as np
13
import pandas as pd
14
import scipy
15
import seaborn as sns
16
from matplotlib import ticker
17
from matplotlib.colors import LinearSegmentedColormap
18
from matplotlib.colors import to_rgb
19
from matplotlib.gridspec import GridSpec  # noqa: TCH002
20
21
import plotly.graph_objects as go
22
from screeninfo import get_monitors
23
24
from klib.utils import _corr_selector
25
from klib.utils import _missing_vals
26
from klib.utils import _validate_input_bool
27
from klib.utils import _validate_input_int
28
from klib.utils import _validate_input_num_data
29
from klib.utils import _validate_input_range
30
from klib.utils import _validate_input_smaller
31
from klib.utils import _validate_input_sum_larger
32
33
__all__ = [
34
    "cat_plot",
35
    "corr_mat",
36
    "corr_plot",
37
    "corr_interactive_plot",
38
    "dist_plot",
39
    "missingval_plot"
40
]
41
42
43
def cat_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
) -> GridSpec | None:
    """Two-dimensional visualization of number and frequency of categorical features.

    For every non-numeric column a bar chart of the most and least frequent
    values plus a short text summary is drawn; below them a heatmap encodes,
    row by row, whether a value belongs to the "top" group, the "bottom" group
    or neither.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values,
        by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common
        values, by default "#d8b365"

    Returns
    -------
    GridSpec | None
        gs: Figure with array of Axes objects, or None (after printing a
        notice) when the data contains no non-numeric columns.
    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if not cols:
        print("No columns with categorical data were detected.")
        return None

    # The columns are overwritten with numeric codes further down, which
    # "category"/"string" dtypes would not accept.
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # With fewer distinct values than requested, split what is available
        # between both groups, favouring the larger of the two requests.
        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        # Explicit positional slicing; [-0:] would select everything, hence
        # the guard for an empty "bottom" group.
        value_counts_top = value_counts.iloc[:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_bot = (
            value_counts.iloc[-lim_bot:] if lim_bot > 0 else value_counts.iloc[:0]
        )
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Encode membership for the heatmap: 10 = top, 0 = bottom, 5 = other.
        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5  # noqa: PLR2004
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Labels are shortened for the x-axis; str() keeps this working for
        # non-string categories such as booleans or numbers in object columns.
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top,
            value_counts_top,
            color=bar_color_top,
            width=0.85,
        )
        ax_top.bar(
            value_counts_idx_bot,
            value_counts_bot,
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.get_xaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap",
        [color_bot_rgb, color_white, color_top_rgb],
        N=200,
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot",
        x=0.5,
        y=0.91,
        fontsize=18,
        color="#111111",
    )

    return gs
193
194
195
def corr_mat(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.DataFrame | pd.Series | np.ndarray | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Correlation threshold, by default 0 unless split = "high" or
        split = "low", in which case default is 0.3. The value is checked
        with bounds -1 and 1.
    target : Optional[pd.Series | np.ndarray | list | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None.
        A string is interpreted as a column of ``data``; that column is
        removed from the features before correlating.
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        * kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more robust in
            smaller datasets than "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in
        red, by default True

    Returns
    -------
    pd.DataFrame | pd.Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        # Return an empty style for non-negative cells instead of the invalid
        # CSS declaration "color: None".
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            # Target named by column: separate it from the features.
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            # Target passed as values: its Series name labels the result.
            target_data = pd.Series(target)
            target = target_data.name

        corr = pd.DataFrame(
            data.corrwith(target_data, method=method, numeric_only=True),
        )
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method, numeric_only=True)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap was renamed to Styler.map in newer pandas releases;
        # prefer the new name and fall back for older versions.
        style_cells = getattr(styler, "map", None) or styler.applymap
        return style_cells(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
273
274
275
def corr_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,  # noqa: ANN003
) -> plt.Axes:
    """2D visualization of the correlation between feature-columns excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by
        default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the
                feature-columns above the threshold
            * neg: visualize all negative correlations between the
                feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for
                which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for
                which abs(corr) < threshold is True

    threshold : float, optional
        Correlation threshold, by default 0 unless split = "high" or
        split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally
                distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic
                relationships.
            * kendall: ranked/ordinal correlation, measures monotonic
                relationships. Computationally more expensive but more robust
                in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name
        or object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False,
        the settings are not displayed, by default False

    kwargs : optional
        Additional elements to control the visualization of the plot. They are
        forwarded to ``sns.heatmap`` and override the defaults below, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the
                upper triangle is shown. Set dev = False in this case to avoid
                overlap.
            * vmax: float, default is calculated from the given correlation
                coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the
                cbar.
            * vmin: float, default is calculated from the given correlation
                coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the
                cbar.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * center: float, default 0
            * fmt: str, default ".2f"
            * Many more kwargs are available, i.e. "alpha" to control
                blending, or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see
        above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    # Without a target the matrix is symmetric, so the upper triangle
    # (including the diagonal) is hidden; with a target nothing is masked.
    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))
    else:
        mask = np.zeros_like(corr, dtype=bool)

    # Colorbar limits are derived from the visible cells only.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Defaults for the heatmap; user supplied kwargs take precedence. Keeping
    # "center" and "fmt" in here (rather than as literal call arguments) lets
    # callers override them without a duplicate-keyword TypeError.
    heatmap_kwargs = {
        "ax": ax,
        "center": 0,
        "fmt": ".2f",
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, **heatmap_kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {heatmap_kwargs['linewidths']} \n\
            - annot_kws: {heatmap_kwargs['annot_kws']} \n\
            - cbar_kws: {heatmap_kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
429
430
def corr_interactive_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    **kwargs,
) -> go.Figure:
    """Two-dimensional visualization of feature correlations using Plotly's Heatmap.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a
        Pandas DataFrame is provided, the index/column information is
        used to label the plots.

    split : Optional[str], optional
        Type of split to be performed
        {None, "pos", "neg", "high", "low"}, by default None

        - None: visualize all correlations between the feature-columns
        - pos: visualize all positive correlations between the
            feature-columns above the threshold
        - neg: visualize all negative correlations between the
            feature-columns below the threshold
        - high: visualize all correlations between the
            feature-columns for which abs(corr) > threshold is True
        - low: visualize all correlations between the
            feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Correlation threshold, by default 0 unless split = "high" or
        split = "low", in which case the default is 0.3

    target : Optional[pd.Series | str], optional
        Specify a target for correlation. For example, the label column
        to generate only the correlations between each feature and the
        label, by default None

    method : Literal['pearson', 'spearman', 'kendall'], optional
        Method for correlation calculation:
        {"pearson", "spearman", "kendall"}, by default "pearson"

        - pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        - spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        - kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more
            robust in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, plotly
        colormap name or object, or list of colors, by default "BrBG"

    figsize : tuple[float, float], optional
        Figure size in inches, by default (12, 10). It is converted to
        pixels with the DPI derived from the detected monitor; one extra
        inch is added to the height.

    annot : bool, optional
        Use to show or hide annotations, by default True

    **kwargs : optional
        Additional arguments passed to ``go.Heatmap``. They override the
        defaults set by this function, which are:

        - hoverongaps (False), xgap (1), ygap (1)
        - colorscale: taken from ``cmap``
        - zmax / zmin: derived from the correlation values
        - text: the rounded correlation values (None if annot is False)
        - texttemplate: "%{text}"
        - textfont: {"size": 12}
        - x / y: column and index labels of the correlation matrix
        - z: the (masked) correlation matrix

        For a comprehensive list of available kwargs, please refer to the
        Plotly Heatmap documentation.

    Returns
    -------
    heatmap : plotly.graph_objs._figure.Figure
        A Plotly Figure object representing the heatmap visualization of
        feature correlations.

    Raises
    ------
    LookupError
        If no monitor with a usable size could be found to derive the DPI.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")

    # Columns are reversed here and the x-axis is reversed in the layout.
    data = pd.DataFrame(data).iloc[:, ::-1]

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    if target is None:
        # Keep only the strictly upper triangle (k=1 also drops the diagonal).
        # This blanks the diagonal without relying on ``corr.to_numpy()``
        # returning a writable view of the frame.
        mask = np.triu(np.ones_like(corr, dtype=bool), k=1)
        corr = corr.where(mask)
    else:
        corr = corr.iloc[::-1, :]

    vmax = np.round(np.nanmax(corr) - 0.05, 2)
    vmin = np.round(np.nanmin(corr) + 0.05, 2)

    vtext = corr.round(2).fillna("") if annot else None

    # Defaults for the heatmap; user supplied kwargs take precedence. The gap
    # and hover settings live here too so callers can override them without a
    # duplicate-keyword TypeError.
    heatmap_kwargs = {
        "hoverongaps": False,
        "xgap": 1,
        "ygap": 1,
        "colorscale": cmap,
        "zmax": vmax,
        "zmin": vmin,
        "text": vtext,
        "texttemplate": "%{text}",
        "textfont": {"size": 12},
        "x": corr.columns,
        "y": corr.index,
        "z": corr,
        **kwargs,
    }

    # Draw heatmap with masked corr and default settings
    heatmap = go.Figure(data=go.Heatmap(**heatmap_kwargs))

    # Derive pixels-per-inch from the primary monitor. ``dpi`` must be bound
    # before the loop: otherwise it is undefined when no monitor is primary.
    dpi = None
    for monitor in get_monitors():
        # Skip monitors without a physical width so the division cannot fail.
        if monitor.is_primary and monitor.width_mm:
            dpi = monitor.width / (monitor.width_mm / 25.4)
            break

    if dpi is None:
        # No usable primary monitor: fall back to the first one reported.
        try:
            monitor = get_monitors()[0]
            dpi = monitor.width / (monitor.width_mm / 25.4)
        except Exception as exc:
            raise LookupError("Monitor doesn't exist") from exc

    heatmap.update_layout(
        title=f"Feature-correlation ({method})",
        title_font={"size": 24},
        title_x=0.5,
        autosize=True,
        width=figsize[0] * dpi,
        height=(figsize[1] + 1) * dpi,
        xaxis={"autorange": "reversed"},
    )

    return heatmap
641
642
643
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: int = 3,
    fill_range: tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: dict[str, Any] | None = None,
    rug_kws: dict[str, Any] | None = None,
    fill_kws: dict[str, Any] | None = None,
    font_kws: dict[str, Any] | None = None,
) -> None | Any:  # noqa: ANN401
    """2D visualization of the distribution of non binary numerical features.

    One KDE plot with rug, shaded quantile range, mean/median/std markers and
    summary statistics is created per numeric column with more than two
    distinct values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default
        "orange"
    size : float, optional
        Controls the plot size, by default 3
    fill_range : tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is
        about two std. deviations above and below the mean, by default
        (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {"alpha": 0.75,
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : dict[str, Any], optional
        Keyword arguments for rugplot(), by default {"color": "#ff3333",
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff",
        "alpha": 0.2}
    font_kws : dict[str, Any], optional
        Keyword arguments to control the font, by default {"color": "#111111",
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        The Axes object of the last plot drawn, for further tweaking. None is
        returned (after printing a notice) if no suitable numeric column was
        found.
    """
    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies keep the caller's dicts untouched)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
        if kde_kws is None
        else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = (
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    )
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
    # ``df`` keeps the full data for the summary statistics, while ``data``
    # may be down-sampled below for plotting.
    df = data.copy()  # noqa: PD901
    data = data.loc[:, data.nunique() > 2]  # noqa: PLR2004
    if data.shape[0] > 10000:  # noqa: PLR2004
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset.",
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    if len(cols) >= 20 and not showall:  # noqa: PLR2004
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True.",
        )
        cols = cols[:20]

    ax = None
    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill
        x, y = ax.lines[0].get_xydata().T
        lower_q = np.quantile(col_df, fill_range[0])
        upper_q = np.quantile(col_df, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= lower_q) & (x <= upper_q)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = scipy.stats.tstd(col_df)
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        x_low, x_high = ax.get_xlim()
        ax.set_xlim(x_low - x_high * 0.05, x_high * 1.03)

        # Annotations and legend
        annotations = (
            (0.9, f"Mean: {mean:.2f}"),
            (0.7, f"Std. dev: {std:.2f}"),
            (0.5, f"Skew: {scipy.stats.skew(col_df):.2f}"),
            (0.3, f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}"),  # Excess Kurtosis
            (0.1, f"Count: {len(col_df)}"),
        )
        for y_pos, label in annotations:
            ax.text(
                0.005,
                y_pos,
                label,
                fontdict=font_kws,
                transform=ax.transAxes,
            )
        ax.legend(loc="upper right")

    # Return only after every selected feature has been plotted; returning
    # from inside the loop would stop after the first column.
    return ax
840
841
842
def missingval_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
) -> GridSpec | None:
    """Two-dimensional visualization of the missing values in a dataset.

    The figure is laid out on a 6x6 grid and consists of four panels: a bar
    chart above a heatmap of ``data.isna()`` (left five grid columns), and a
    textual summary above a per-row scatter plot (rightmost grid column). The
    missing-value statistics shown are the ones reported by ``_missing_vals``.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information
        can be found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop
        columns without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid
        matplotlib color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec | None
        gs: Figure with array of Axes objects. ``None`` is returned (and a
        message is printed) when the dataset contains no missing values, in
        which case no figure is created.
    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Order columns by their number of missing values (descending) and
        # keep only those columns that have at least one missing value.
        missing_per_col = data.isna().sum(axis=0).sort_values(ascending=False)
        complete_cols = (
            missing_per_col[missing_per_col.to_numpy() == 0].keys().tolist()
        )
        final_cols = missing_per_col.drop(complete_cols).keys().tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    # Nothing to plot; bail out before any figure is created. This check also
    # guards the divisions by `total_datapoints` and `np.max(mv_cols)` below.
    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])  # top-left: bar chart
    ax2 = fig.add_subplot(gs[1:, :5])  # bottom-left: heatmap
    ax3 = fig.add_subplot(gs[:1, 5:])  # top-right: text summary
    ax4 = fig.add_subplot(gs[1:, 5:])  # bottom-right: scatter plot

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)  # noqa: FBT003
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols, strict=True):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            # vertical offset grows with the bar height, but never below 0.075
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            rotation=90,
            alpha=0.5,
            fontsize="11",
        )

    ax1.set_frame_on(True)  # noqa: FBT003
    for spine in ax1.spines.values():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # keep every fifth y-tick, rounded to the nearest multiple of ten
    ax2.set_yticks(np.round(ax2.get_yticks()[::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="12",
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)  # noqa: FBT003
    ax3.get_yaxis().set_visible(False)  # noqa: FBT003
    ax3.set(frame_on=False)

    # One text line per statistic, placed top to bottom in axes coordinates.
    summary_lines = (
        f"Total: {np.round(total_datapoints/1000,1)}K",
        f"Missing: {np.round(mv_total/1000,1)}K",
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
    )
    summary_y_positions = (0.875, 0.675, 0.475, 0.275, 0.075)
    for y_pos, line in zip(summary_y_positions, summary_lines, strict=True):
        ax3.text(
            0.025,
            y_pos,
            line,
            transform=ax3.transAxes,
            fontdict=fontax3,
        )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)  # noqa: FBT003
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    # x position, marker size and marker color are all driven by `mv_rows`
    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot",
        x=0.45,
        y=0.94,
        fontsize=18,
        color="#111111",
    )

    return gs
1021