Test Failed
Push to main (5bb691...d8c023)
by Andreas
01:49
created

klib.describe.corr_interactive_plot()   D

Complexity

Conditions 11

Size

Total Lines 214
Code Lines 71

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 71
nop 9
dl 0
loc 214
rs 4.729
c 0
b 0
f 0

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex functions like klib.describe.corr_interactive_plot() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields, methods, or (within a function) variables and parameters that share the same prefixes or suffixes.

Once you have determined the parts that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
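In corr_interactive_plot() one cohesive group is the color scale: `cmap`, `split` and the `vmin`/`vmax` limits derived from the correlation values always travel together into `go.Heatmap` as `colorscale`, `zmin` and `zmax`. A minimal sketch of pulling that group into its own class follows. The `ColorScale` class and its methods are hypothetical; the class takes a plain iterable of floats rather than a DataFrame so it runs without pandas, and it reproduces the listing's rules: maximum minus 0.05 and minimum plus 0.05, rounded to two decimals, then mirrored for `split="neg"` or `split="pos"`.

```python
from __future__ import annotations

import math
from collections.abc import Iterable
from dataclasses import dataclass


@dataclass(frozen=True)
class ColorScale:
    """Hypothetical class grouping the color-scale settings of the heatmap."""

    colorscale: str
    zmin: float
    zmax: float

    @classmethod
    def from_values(
        cls,
        values: Iterable[float],
        split: str | None = None,
        cmap: str = "BrBG",
    ) -> ColorScale:
        # Ignore NaNs, as np.nanmax / np.nanmin do in the listing.
        finite = [v for v in values if not math.isnan(v)]
        if not finite:
            # Hypothetical choice: np.nanmax would return NaN with a warning.
            raise ValueError("no non-NaN correlation values")
        zmax = round(max(finite) - 0.05, 2)
        zmin = round(min(finite) + 0.05, 2)
        # Same order as the listing: zmax is mirrored first, then zmin.
        zmax = -zmin if split == "neg" else zmax
        zmin = -zmax if split == "pos" else zmin
        return cls(colorscale=cmap, zmin=zmin, zmax=zmax)

    def heatmap_kwargs(self) -> dict[str, object]:
        """Return the three entries in the form passed to go.Heatmap."""
        return {"colorscale": self.colorscale, "zmin": self.zmin, "zmax": self.zmax}
```

The plotting function could then fill those three entries with something like `ColorScale.from_values(corr.to_numpy().ravel(), split, cmap).heatmap_kwargs()`, and the limit rules could be unit-tested in isolation.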

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.
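In this module the four correlation-selection parameters `split`, `threshold`, `target` and `method` are repeated in the signatures of corr_mat(), corr_plot() and corr_interactive_plot(), and both plotting functions forward them unchanged to corr_mat(). One common remedy is the Introduce Parameter Object refactoring. The sketch below is hypothetical (the `CorrSpec` name, its validation and `as_kwargs()` are not part of klib); it only bundles and checks the values, using the threshold range and the choices documented in the listing's docstrings.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

SPLITS = {None, "pos", "neg", "high", "low"}
METHODS = {"pearson", "spearman", "kendall"}


@dataclass(frozen=True)
class CorrSpec:
    """Hypothetical parameter object for the shared correlation arguments."""

    split: str | None = None
    threshold: float = 0.0
    target: Any = None  # column name, Series or array, as in the original
    method: str = "pearson"

    def __post_init__(self) -> None:
        # Mirror the documented choices and the [-1, 1] threshold check.
        if self.split not in SPLITS:
            raise ValueError(f"split must be one of {SPLITS}, got {self.split!r}")
        if not -1 <= self.threshold <= 1:
            raise ValueError(f"threshold must be in [-1, 1], got {self.threshold}")
        if self.method not in METHODS:
            raise ValueError(f"method must be one of {METHODS}, got {self.method!r}")

    def as_kwargs(self) -> dict[str, Any]:
        """Return the fields as keyword arguments, e.g. for corr_mat(data, **kw)."""
        return {
            "split": self.split,
            "threshold": self.threshold,
            "target": self.target,
            "method": self.method,
        }
```

A plotting function could then take a single `spec: CorrSpec = CorrSpec()` argument and call `corr_mat(data, colored=False, **spec.as_kwargs())`; for corr_interactive_plot() that would reduce nop from 9 to 6.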

There are several approaches to avoid long parameter lists:

"""Functions for descriptive analytics.

:author: Andreas Kanz

"""
from __future__ import annotations

from typing import Any
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy
import seaborn as sns
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import to_rgb
from matplotlib.gridspec import GridSpec  # noqa: TCH002
from screeninfo import get_monitors

from klib.utils import _corr_selector
from klib.utils import _missing_vals
from klib.utils import _validate_input_bool
from klib.utils import _validate_input_int
from klib.utils import _validate_input_num_data
from klib.utils import _validate_input_range
from klib.utils import _validate_input_smaller
from klib.utils import _validate_input_sum_larger

__all__ = [
    "cat_plot",
    "corr_interactive_plot",
    "corr_mat",
    "corr_plot",
    "dist_plot",
    "missingval_plot",
]


def cat_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
) -> GridSpec:
    """Two-dimensional visualization of number and frequency of categorical features.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" least frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by \
        default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by \
        default "#d8b365"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects
    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if len(cols) == 0:
        print("No columns with categorical data were detected.")
        return None

    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        value_counts_top = value_counts[:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_bot = value_counts[-lim_bot:] if lim_bot > 0 else pd.DataFrame()
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5  # noqa: PLR2004
        data[col] = data[col].rolling(2, min_periods=1).mean()

        value_counts_idx_top = [elem[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [elem[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top,
            value_counts_top,
            color=bar_color_top,
            width=0.85,
        )
        ax_top.bar(
            value_counts_idx_bot,
            value_counts_bot,
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.get_xaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap",
        [color_bot_rgb, color_white, color_top_rgb],
        N=200,
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot",
        x=0.5,
        y=0.91,
        fontsize=18,
        color="#111111",
    )

    return gs


def corr_mat(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.DataFrame | pd.Series | np.ndarray | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
        split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.DataFrame | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed \
            and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. \
            Computationally more expensive but more robust in smaller datasets than \
            "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in red, by \
        default True

    Returns
    -------
    pd.DataFrame | pd.Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        color = "#FF3344" if val < 0 else None
        return f"color: {color}"

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        target_data = []
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)

        elif isinstance(target, (list, pd.Series, np.ndarray)):
            target_data = pd.Series(target)
            target = target_data.name

        corr = pd.DataFrame(
            data.corrwith(target_data, method=method, numeric_only=True),
        )
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method, numeric_only=True)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        return corr.style.applymap(color_negative_red).format("{:.2f}", na_rep="-")
    return corr


def corr_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,  # noqa: ANN003
) -> plt.Axes:
    """2D visualization of the correlation between feature-columns excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by default \
        None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns \
                above the threshold
            * neg: visualize all negative correlations between the feature-columns \
                below the threshold
            * high: visualize all correlations between the feature-columns for \
                which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which \
                abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
            split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally \
                distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships. \
                Computationally more expensive but more robust in smaller datasets \
                than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or \
        object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the \
        settings are not displayed, by default False

    kwargs : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the upper \
                triangle is shown. Set dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation \
                coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the cbar.
            * vmin: float, default is calculated from the given correlation \
                coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the cbar.
            * linewidths: float, default 0.5
                Controls the line-width in between the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when \
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * Many more kwargs are available, e.g. "alpha" to control blending, or \
                options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))

    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, center=0, fmt=".2f", **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax


def corr_interactive_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0.0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    **kwargs,  # noqa: ANN003
) -> go.Figure:
    """Interactive 2D visualization of the correlation between feature-columns.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a
        Pandas DataFrame is provided, the index/column information is
        used to label the plots.

    split : Optional[str], optional
        Type of split to be performed
        {None, "pos", "neg", "high", "low"}, by default None

        - None: visualize all correlations between the feature-columns

        - pos: visualize all positive correlations between the
            feature-columns above the threshold

        - neg: visualize all negative correlations between the
            feature-columns below the threshold

        - high: visualize all correlations between the
            feature-columns for which abs(corr) > threshold is True

        - low: visualize all correlations between the
            feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold,
        by default 0 unless split = "high" or split = "low", in
        which case the default is 0.3

    target : Optional[pd.Series | str], optional
        Specify a target for correlation. For example, the label column
        to generate only the correlations between each feature and the
        label, by default None

    method : Literal['pearson', 'spearman', 'kendall'], optional
        Method for correlation calculation:
        {"pearson", "spearman", "kendall"}, by default "pearson"

        - pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        - spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        - kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more
            robust in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, plotly
        colormap name or object, or list of colors, by default "BrBG"

    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)

    annot : bool, optional
        Use to show or hide annotations, by default True

    **kwargs : optional
        Additional elements to control the visualization of the plot.
            These additional arguments will be passed to the `go.Heatmap`
            function in Plotly.

        Specific kwargs used in this function:

        - colorscale: str or list, optional
            The colorscale to be used for the heatmap. It controls the
            mapping of data values to colors in the heatmap.

        - zmax: float, optional
            The maximum value of the color scale. It limits the upper
            range of the colorbar displayed on the heatmap.

        - zmin: float, optional
            The minimum value of the color scale. It limits the lower
            range of the colorbar displayed on the heatmap.

        - text: pd.DataFrame, optional
            A DataFrame containing text to display on the heatmap. This
            text will be shown on the heatmap cells corresponding to the
            correlation values.

        - texttemplate: str, optional
            A text template string to format the text display on the
            heatmap. This allows you to customize how the text appears,
            including the display of the correlation values.

        - textfont: dict, optional
            A dictionary specifying the font properties for the text on
            the heatmap. You can customize the font size, color, family,
            etc., for the text annotations.

        - x: list, optional
            The list of column names for the x-axis of the heatmap. It
            allows you to customize the labels displayed on the x-axis.

        - y: list, optional
            The list of row names for the y-axis of the heatmap. It
            allows you to customize the labels displayed on the y-axis.

        - z: pd.DataFrame, optional
            The 2D array representing the correlation matrix to be
            visualized. This is the core data for generating the heatmap,
            containing the correlation values.

        - Many more kwargs are available, e.g., "hovertemplate" to control
            the hover-label template, or options to adjust the borderwidth
            and opacity of the heatmap. For a comprehensive list of
            available kwargs, please refer to the Plotly Heatmap documentation.

        Kwargs can be supplied through a dictionary of key-value pairs
        (see above) and can be found in Plotly Heatmap documentation.

    Returns
    -------
    heatmap : plotly.graph_objs._figure.Figure
        A Plotly Figure object representing the heatmap visualization of
        feature correlations.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")

    data = pd.DataFrame(data).iloc[:, ::-1]

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))
        np.fill_diagonal(corr.to_numpy(), np.nan)
        corr = corr.where(mask == 1)
    else:
        corr = corr.iloc[::-1, :]

    vmax = np.round(np.nanmax(corr) - 0.05, 2)
    vmin = np.round(np.nanmin(corr) + 0.05, 2)

    vmax = -vmin if split == "neg" else vmax
    vmin = -vmax if split == "pos" else vmin

    vtext = corr.round(2).fillna("") if annot else None

    # Specify kwargs for the heatmap
    kwargs = {
        "colorscale": cmap,
        "zmax": vmax,
        "zmin": vmin,
        "text": vtext,
        "texttemplate": "%{text}",
        "textfont": {"size": 12},
        "x": corr.columns,
        "y": corr.index,
        "z": corr,
        **kwargs,
    }

    # Draw heatmap with masked corr and default settings
    heatmap = go.Figure(
        data=go.Heatmap(
            hoverongaps=False,
            xgap=1,
            ygap=1,
            **kwargs,
        ),
    )

    dpi = 96  # more or less arbitrary default value
    for monitor in get_monitors():
        if monitor.is_primary:
            if monitor.width_mm is None or monitor.height_mm is None:
                continue
            dpi = monitor.width / (monitor.width_mm / 25.4)
            break

    if dpi is None:
        try:
            monitor = get_monitors()[0]
            dpi = monitor.width / (monitor.width_mm / 25.4)
        except ValueError as exc:
            msg = "Monitor doesn't exist"
            raise LookupError(msg) from exc

    heatmap.update_layout(
        title=f"Feature-correlation ({method})",
        title_font={"size": 24},
        title_x=0.5,
        autosize=True,
        width=figsize[0] * dpi,
        height=(figsize[1] + 1) * dpi,
        xaxis={"autorange": "reversed"},
    )

    return heatmap


def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: int = 3,
    fill_range: tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: dict[str, Any] | None = None,
    rug_kws: dict[str, Any] | None = None,
    fill_kws: dict[str, Any] | None = None,
    font_kws: dict[str, Any] | None = None,
) -> None | Any:  # noqa: ANN401
657
    """2D visualization of the distribution of non binary numerical features.
658
659
    Parameters
660
    ----------
661
    data : pd.DataFrame
662
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
663
        is provided, the index/column information is used to label the plots
664
    mean_color : str, optional
665
        Color of the vertical line indicating the mean of the data, by default "orange"
666
    size : float, optional
667
        Controls the plot size, by default 3
668
    fill_range : tuple, optional
669
        Set the quantiles for shading. Default spans 95% of the data, which is about \
670
        two std. deviations above and below the mean, by default (0.025, 0.975)
671
    showall : bool, optional
672
        Set to True to remove the output limit of 20 plots, by default False
673
    kde_kws : dict[str, Any], optional
674
        Keyword arguments for kdeplot(), by default {"color": "k", "alpha": 0.75, \
675
        "linewidth": 1.5, "bw_adjust": 0.8}
676
    rug_kws : dict[str, Any], optional
677
        Keyword arguments for rugplot(), by default {"color": "#ff3333", \
678
        "alpha": 0.15, "lw": 3, "height": 0.075}
679
    fill_kws : dict[str, Any], optional
680
        Keyword arguments to control the fill, by default {"color": "#80d4ff", \
681
        "alpha": 0.2}
682
    font_kws : dict[str, Any], optional
683
        Keyword arguments to control the font, by default {"color":  "#111111", \
684
        "weight": "normal", "size": 11}
685
686
    Returns
687
    -------
688
    ax: matplotlib Axes
689
        Returns the Axes object with the plot for further tweaking.
690
    """
691
    # Validate Inputs
692
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
693
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
694
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
695
    _validate_input_bool(showall, "showall")
696
697
    # Handle dictionary defaults
698
    kde_kws = (
699
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
700
        if kde_kws is None
701
        else kde_kws.copy()
702
    )
703
    rug_kws = (
704
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
705
        if rug_kws is None
706
        else rug_kws.copy()
707
    )
708
    fill_kws = (
709
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
710
    )
711
    font_kws = (
712
        {"color": "#111111", "weight": "normal", "size": 11}
713
        if font_kws is None
714
        else font_kws.copy()
715
    )
716
717
    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
718
    df = data.copy()  # noqa: PD901
719
    data = data.loc[:, data.nunique() > 2]  # noqa: PLR2004
720
    if data.shape[0] > 10000:  # noqa: PLR2004
721
        data = data.sample(n=10000, random_state=408)
722
        print(
723
            "Large dataset detected, using 10000 random samples for the plots. Summary"
724
            " statistics are still based on the entire dataset.",
725
        )
726
    cols = list(data.select_dtypes(include=["number"]).columns)
727
    data = data[cols]
728
729
    if not cols:
730
        print("No columns with numeric data were detected.")
731
        return None
732
733
    if len(cols) >= 20 and not showall:  # noqa: PLR2004
734
        print(
735
            "Note: The number of non binary numerical features is very large "
736
            f"({len(cols)}), please consider splitting the data. Showing plots for "
737
            "the first 20 numerical features. Override this by setting showall=True.",
738
        )
739
        cols = cols[:20]
740
741
    for col in cols:
742
        col_data = data[col].dropna(axis=0)
743
        col_df = df[col].dropna(axis=0)
744
745
        g = sns.displot(
746
            col_data,
747
            kind="kde",
748
            rug=True,
749
            height=size,
750
            aspect=5,
751
            legend=False,
752
            rug_kws=rug_kws,
753
            **kde_kws,
754
        )
755
756
        # Vertical lines and fill
        x, y = g.axes[0, 0].lines[0].get_xydata().T
        g.axes[0, 0].fill_between(
            x,
            y,
            where=(
                (x >= np.quantile(col_df, fill_range[0]))
                & (x <= np.quantile(col_df, fill_range[1]))
            ),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        std = scipy.stats.tstd(col_df)
        g.axes[0, 0].vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        g.axes[0, 0].vlines(
            x=np.median(col_df),
            ymin=0,
            ymax=np.interp(np.median(col_df), x, y),
            ls=":",
            color=".3",
            label="median",
        )
        g.axes[0, 0].vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        g.axes[0, 0].set_ylim(0)
        g.axes[0, 0].set_xlim(
            g.axes[0, 0].get_xlim()[0] - g.axes[0, 0].get_xlim()[1] * 0.05,
            g.axes[0, 0].get_xlim()[1] * 1.03,
        )

        # Annotations and legend
        g.axes[0, 0].text(
            0.005,
            0.9,
            f"Mean: {mean:.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.7,
            f"Std. dev: {std:.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.5,
            f"Skew: {scipy.stats.skew(col_df):.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.3,
            f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}",  # Excess Kurtosis
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.1,
            f"Count: {len(col_df)}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].legend(loc="upper right")

    return g.axes[0, 0]


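def _kde_band_sketch(values, fill_range=(0.05, 0.95), grid_size=200):
    # Hypothetical helper for illustration only; it is not part of klib's API.
    # A minimal sketch of the shading logic in dist_plot(). It assumes
    # scipy.stats.gaussian_kde as a stand-in for the KDE curve that seaborn
    # draws, and for simplicity uses a grid spanning only the observed data range.
    # Steps:
    #   1. Evaluate the density on a grid.
    #   2. Mark the grid points between the two quantiles in fill_range; this
    #      boolean array plays the role of the `where=` mask for fill_between().
    #   3. Use np.interp() to find the curve height at the mean, which
    #      dist_plot() passes as `ymax` to the vertical marker line.
    import numpy as np
    import scipy.stats

    values = np.asarray(values, dtype=float)
    x = np.linspace(values.min(), values.max(), grid_size)
    y = scipy.stats.gaussian_kde(values)(x)  # density evaluated on the grid

    # Grid points that fall inside the shaded quantile band
    lower, upper = np.quantile(values, fill_range)
    mask = (x >= lower) & (x <= upper)

    # Linear interpolation of the curve height at the mean
    mean_height = np.interp(values.mean(), x, y)
    return x, y, mask, mean_height

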
def missingval_plot(  # noqa: PLR0915
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
) -> GridSpec | None:
    """Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used, e.g. "Greys" or "RdPu". More information can \
        be found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns by their number of missing values in descending order and drop \
        columns without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid matplotlib \
        color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec | None
        gs: GridSpec of the figure holding the four Axes objects, or None if the \
        dataset contains no missing values
    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = (
            mv_cols_sorted.drop(
                mv_cols_sorted[mv_cols_sorted.to_numpy() == 0].keys().tolist(),
            )
            .keys()
            .tolist()
        )
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)  # noqa: FBT003
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols, strict=True):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            rotation=90,
            alpha=0.5,
            fontsize="11",
        )

    ax1.set_frame_on(True)  # noqa: FBT003
    for _, spine in ax1.spines.items():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    ax2.set_yticks(np.round(ax2.get_yticks()[::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="12",
    )
    ax2.tick_params(length=1, colors="#111111")
    for _, spine in ax2.spines.items():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)  # noqa: FBT003
    ax3.get_yaxis().set_visible(False)  # noqa: FBT003
    ax3.set(frame_on=False)

    ax3.text(
        0.025,
        0.875,
        f"Total: {np.round(total_datapoints/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.675,
        f"Missing: {np.round(mv_total/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.475,
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.275,
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.075,
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)  # noqa: FBT003
    for _, spine in ax4.spines.items():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot",
        x=0.45,
        y=0.94,
        fontsize=18,
        color="#111111",
    )

    return gs
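

def _missing_summary_sketch(data):
    # Hypothetical helper for illustration only; it is not part of klib's API.
    # A minimal sketch of the figures that missingval_plot() prints in its
    # summary panel (ax3). It uses plain pandas in place of klib's internal
    # _missing_vals(), whose exact return values are assumed here rather than
    # reproduced. It also shows a shorter equivalent of the sort=True column
    # filter: columns ordered by missing count (descending), with fully
    # populated columns dropped.
    import pandas as pd

    data = pd.DataFrame(data)
    is_missing = data.isna()
    mv_cols = is_missing.sum(axis=0)  # missing values per column
    mv_rows = is_missing.sum(axis=1)  # missing values per row
    mv_total = int(mv_cols.sum())
    total_datapoints = data.shape[0] * data.shape[1]

    # Counts are never negative, so keeping "> 0" equals dropping "== 0"
    sorted_counts = mv_cols.sort_values(ascending=False)
    sorted_cols = sorted_counts[sorted_counts > 0].index.tolist()

    return {
        "total": total_datapoints,
        "missing": mv_total,
        "relative_pct": round(mv_total / total_datapoints * 100, 1),
        "max_col_pct": round(float(mv_cols.max()) / data.shape[0] * 100),
        "max_row_pct": round(float(mv_rows.max()) / data.shape[1] * 100),
        "sorted_cols": sorted_cols,
    }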