Test Failed
Push to main (a48a6a...ecf315) by Andreas, 01:41, created

klib.describe.corr_interactive_plot() (rating: F)

Complexity
    Conditions: 13

Size
    Total Lines: 216
    Code Lines: 72

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric  Value
cc      13
eloc    72
nop     9
dl      0
loc     216
rs      3.8563
c       0
b       0
f       0

How to fix: Long Method, Complexity, Many Parameters

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex code units like klib.describe.corr_interactive_plot() (here a function rather than a class) often do many different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods (or, inside a function, local variables) that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:
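One such approach is Introduce Parameter Object. In the listing below, `split`, `threshold` and `method` travel together through corr_mat(), corr_plot() and corr_interactive_plot(). Bundling them would cut corr_interactive_plot() from 9 parameters (nop) to 7. The sketch is hypothetical: `CorrSettings` and `correlations` are illustrative names, not klib API, and applying `split`/`threshold` (done in klib by its internal `_corr_selector`) is left out.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, Optional

import pandas as pd

_METHODS = ("pearson", "spearman", "kendall")


@dataclass(frozen=True)
class CorrSettings:
    """Hypothetical parameter object for options shared by the corr_* functions."""

    split: Optional[Literal["pos", "neg", "high", "low"]] = None
    threshold: float = 0.0
    method: Literal["pearson", "spearman", "kendall"] = "pearson"

    def __post_init__(self) -> None:
        # Validate once at construction instead of in every function.
        if not -1 <= self.threshold <= 1:
            raise ValueError("threshold must be between -1 and 1")
        if self.method not in _METHODS:
            raise ValueError(f"method must be one of {_METHODS}, got {self.method!r}")


def correlations(data: pd.DataFrame, settings: CorrSettings) -> pd.DataFrame:
    """Hypothetical consumer: one argument replaces split, threshold and method."""
    # settings.split and settings.threshold are not applied in this sketch.
    return data.corr(method=settings.method, numeric_only=True)
```

Because the dataclass is frozen, a single `CorrSettings(split="high", threshold=0.3)` instance can be passed safely to several plotting calls without one of them mutating it.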

"""Functions for descriptive analytics.

:author: Andreas Kanz

"""
from __future__ import annotations

from typing import Any
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import scipy
import seaborn as sns
from matplotlib import ticker
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import to_rgb
from matplotlib.gridspec import GridSpec  # noqa: TCH002
from screeninfo import get_monitors
from screeninfo import ScreenInfoError

from klib.utils import _corr_selector
from klib.utils import _missing_vals
from klib.utils import _validate_input_bool
from klib.utils import _validate_input_int
from klib.utils import _validate_input_num_data
from klib.utils import _validate_input_range
from klib.utils import _validate_input_smaller
from klib.utils import _validate_input_sum_larger

__all__ = [
    "cat_plot",
    "corr_interactive_plot",
    "corr_mat",
    "corr_plot",
    "dist_plot",
    "missingval_plot",
]


def cat_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
) -> GridSpec:
    """Two-dimensional visualization of number and frequency of categorical features.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" least frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by \
        default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by \
        default "#d8b365"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects
    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if len(cols) == 0:
        print("No columns with categorical data were detected.")
        return None

    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        value_counts_top = value_counts[:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_bot = value_counts[-lim_bot:] if lim_bot > 0 else pd.DataFrame()
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5  # noqa: PLR2004
        data[col] = data[col].rolling(2, min_periods=1).mean()

        value_counts_idx_top = [elem[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [elem[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Bar charts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top,
            value_counts_top,
            color=bar_color_top,
            width=0.85,
        )
        ax_top.bar(
            value_counts_idx_bot,
            value_counts_bot,
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.get_xaxis().set_visible(False)  # noqa: FBT003
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap",
        [color_bot_rgb, color_white, color_top_rgb],
        N=200,
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot",
        x=0.5,
        y=0.91,
        fontsize=18,
        color="#111111",
    )

    return gs


def corr_mat(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.DataFrame | pd.Series | np.ndarray | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
        split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.DataFrame | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed \
            and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships. \
            Computationally more expensive but more robust in smaller datasets than \
            "spearman"
    colored : bool, optional
        If True, negative values in the correlation matrix are colored in red, by \
        default True

    Returns
    -------
    pd.DataFrame | pd.Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        color = "#FF3344" if val < 0 else None
        return f"color: {color}"

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        target_data = []
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)

        elif isinstance(target, (list, pd.Series, np.ndarray)):
            target_data = pd.Series(target)
            target = target_data.name

        corr = pd.DataFrame(
            data.corrwith(target_data, method=method, numeric_only=True),
        )
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method, numeric_only=True)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        return corr.style.applymap(color_negative_red).format("{:.2f}", na_rep="-")
    return corr


def corr_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,  # noqa: ANN003
) -> plt.Axes:
    """2D visualization of the correlation between feature-columns excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by default \
        None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns \
                above the threshold
            * neg: visualize all negative correlations between the feature-columns \
                below the threshold
            * high: visualize all correlations between the feature-columns for \
                which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which \
                abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0 unless \
            split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the \
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally \
                distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships. \
                Computationally more expensive but more robust in smaller datasets \
                than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or \
        object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the \
        settings are not displayed, by default False

    kwargs : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False, the entire correlation matrix, including the upper \
                triangle, is shown. Set dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation \
                coefficients.
                Value with vmin <= vmax <= 1 (and vmax >= -1), limits the range of \
                the cbar.
            * vmin: float, default is calculated from the given correlation \
                coefficients.
                Value with -1 <= vmin <= vmax (and vmin <= 1), limits the range of \
                the cbar.
            * linewidths: float, default 0.5
                Controls the line-width in between the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when \
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * Many more kwargs are available, e.g. "alpha" to control blending, or \
                options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))

    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, center=0, fmt=".2f", **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings (report the values actually used, including user overrides)
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {kwargs['vmax']} \n\
                - vmin: {kwargs['vmin']} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax


def corr_interactive_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0.0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    annot: bool = True,
    **kwargs,  # noqa: ANN003
) -> go.Figure:
    """Interactive 2D visualization of the correlation between feature-columns.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a
        Pandas DataFrame is provided, the index/column information is
        used to label the plots.

    split : Optional[str], optional
        Type of split to be performed
        {None, "pos", "neg", "high", "low"}, by default None

        - None: visualize all correlations between the feature-columns

        - pos: visualize all positive correlations between the
            feature-columns above the threshold

        - neg: visualize all negative correlations between the
            feature-columns below the threshold

        - high: visualize all correlations between the
            feature-columns for which abs(corr) > threshold is True

        - low: visualize all correlations between the
            feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold,
        by default 0 unless split = "high" or split = "low", in
        which case the default is 0.3

    target : Optional[pd.Series | str], optional
        Specify a target for correlation. For example, the label column
        to generate only the correlations between each feature and the
        label, by default None

    method : Literal['pearson', 'spearman', 'kendall'], optional
        Method for correlation calculation:
        {"pearson", "spearman", "kendall"}, by default "pearson"

        - pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        - spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        - kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more
            robust in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, plotly
        colormap name or object, or list of colors, by default "BrBG"

    figsize : tuple[float, float], optional
        Figure size in inches, by default (12, 10). It is converted to
        pixels using the DPI of the detected screen (96 if unknown), and
        one extra inch is added to the height.

    annot : bool, optional
        Use to show or hide annotations, by default True

    **kwargs : optional
        Additional elements to control the visualization of the plot.
        These additional arguments will be passed to the `go.Heatmap`
        function in Plotly.

        Specific kwargs used in this function:

        - colorscale: str or list, optional
            The colorscale to be used for the heatmap. It controls the
            mapping of data values to colors in the heatmap.

        - zmax: float, optional
            The maximum value of the color scale. It limits the upper
            range of the colorbar displayed on the heatmap.

        - zmin: float, optional
            The minimum value of the color scale. It limits the lower
            range of the colorbar displayed on the heatmap.

        - text: pd.DataFrame, optional
            A DataFrame containing text to display on the heatmap. This
            text will be shown on the heatmap cells corresponding to the
            correlation values.

        - texttemplate: str, optional
            A text template string to format the text display on the
            heatmap. This allows you to customize how the text appears,
            including the display of the correlation values.

        - textfont: dict, optional
            A dictionary specifying the font properties for the text on
            the heatmap. You can customize the font size, color, family,
            etc., for the text annotations.

        - x: list, optional
            The list of column names for the x-axis of the heatmap. It
            allows you to customize the labels displayed on the x-axis.

        - y: list, optional
            The list of row names for the y-axis of the heatmap. It
            allows you to customize the labels displayed on the y-axis.

        - z: pd.DataFrame, optional
            The 2D array representing the correlation matrix to be
            visualized. This is the core data for generating the heatmap,
            containing the correlation values.

        - Many more kwargs are available, e.g., "hovertemplate" to control
            the text shown in the hover label, or options to adjust the
            borderwidth and opacity of the heatmap. For a comprehensive list
            of available kwargs, please refer to the Plotly Heatmap
            documentation.

        Kwargs can be supplied through a dictionary of key-value pairs
        (see above) and can be found in the Plotly Heatmap documentation.

    Returns
    -------
    heatmap : plotly.graph_objects.Figure
        A Plotly Figure object representing the heatmap visualization of
        feature correlations.
    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")

    data = pd.DataFrame(data).iloc[:, ::-1]

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))
        np.fill_diagonal(corr.to_numpy(), np.nan)
        corr = corr.where(mask == 1)
    else:
        corr = corr.iloc[::-1, :]

    vmax = np.round(np.nanmax(corr) - 0.05, 2)
    vmin = np.round(np.nanmin(corr) + 0.05, 2)

    vmax = -vmin if split == "neg" else vmax
    vmin = -vmax if split == "pos" else vmin

    vtext = corr.round(2).fillna("") if annot else None

    # Specify kwargs for the heatmap
    kwargs = {
        "colorscale": cmap,
        "zmax": vmax,
        "zmin": vmin,
        "text": vtext,
        "texttemplate": "%{text}",
        "textfont": {"size": 12},
        "x": corr.columns,
        "y": corr.index,
        "z": corr,
        **kwargs,
    }

    # Draw heatmap with masked corr and default settings
    heatmap = go.Figure(
        data=go.Heatmap(
            hoverongaps=False,
            xgap=1,
            ygap=1,
            **kwargs,
        ),
    )

    dpi = None
    try:
        for monitor in get_monitors():
            if monitor.is_primary:
                if monitor.width_mm is None or monitor.height_mm is None:
                    continue
                dpi = monitor.width / (monitor.width_mm / 25.4)
                break

        if dpi is None:
            monitor = get_monitors()[0]
            if monitor.width_mm is None or monitor.height_mm is None:
                dpi = 96  # more or less arbitrary default value
            else:
                dpi = monitor.width / (monitor.width_mm / 25.4)
    except ScreenInfoError:
        dpi = 96

    heatmap.update_layout(
        title=f"Feature-correlation ({method})",
        title_font={"size": 24},
        title_x=0.5,
        autosize=True,
        width=figsize[0] * dpi,
        height=(figsize[1] + 1) * dpi,
        xaxis={"autorange": "reversed"},
    )

    return heatmap


649
def dist_plot(
650
    data: pd.DataFrame,
651
    mean_color: str = "orange",
652
    size: int = 3,
653
    fill_range: tuple = (0.025, 0.975),
654
    showall: bool = False,
655
    kde_kws: dict[str, Any] | None = None,
656
    rug_kws: dict[str, Any] | None = None,
657
    fill_kws: dict[str, Any] | None = None,
658
    font_kws: dict[str, Any] | None = None,
659
) -> None | Any:  # noqa: ANN401
660
    """2D visualization of the distribution of non binary numerical features.
661
662
    Parameters
663
    ----------
664
    data : pd.DataFrame
665
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
666
        is provided, the index/column information is used to label the plots
667
    mean_color : str, optional
668
        Color of the vertical line indicating the mean of the data, by default "orange"
669
    size : float, optional
670
        Controls the plot size, by default 3
671
    fill_range : tuple, optional
672
        Set the quantiles for shading. Default spans 95% of the data, which is about \
673
        two std. deviations above and below the mean, by default (0.025, 0.975)
674
    showall : bool, optional
675
        Set to True to remove the output limit of 20 plots, by default False
676
    kde_kws : dict[str, Any], optional
677
        Keyword arguments for kdeplot(), by default {"color": "k", "alpha": 0.75, \
678
        "linewidth": 1.5, "bw_adjust": 0.8}
679
    rug_kws : dict[str, Any], optional
680
        Keyword arguments for rugplot(), by default {"color": "#ff3333", \
681
        "alpha": 0.15, "lw": 3, "height": 0.075}
682
    fill_kws : dict[str, Any], optional
683
        Keyword arguments to control the fill, by default {"color": "#80d4ff", \
684
        "alpha": 0.2}
685
    font_kws : dict[str, Any], optional
686
        Keyword arguments to control the font, by default {"color":  "#111111", \
687
        "weight": "normal", "size": 11}
688
689
    Returns
690
    -------
691
    ax: matplotlib Axes
692
        Returns the Axes object with the plot for further tweaking.
693
    """
694
    # Validate Inputs
695
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
696
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
697
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
698
    _validate_input_bool(showall, "showall")
699
700
    # Handle dictionary defaults
701
    kde_kws = (
702
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
703
        if kde_kws is None
704
        else kde_kws.copy()
705
    )
706
    rug_kws = (
707
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
708
        if rug_kws is None
709
        else rug_kws.copy()
710
    )
711
    fill_kws = (
712
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
713
    )
714
    font_kws = (
715
        {"color": "#111111", "weight": "normal", "size": 11}
716
        if font_kws is None
717
        else font_kws.copy()
718
    )
719
720
    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
721
    df = data.copy()  # noqa: PD901
722
    data = data.loc[:, data.nunique() > 2]  # noqa: PLR2004
723
    if data.shape[0] > 10000:  # noqa: PLR2004
724
        data = data.sample(n=10000, random_state=408)
725
        print(
726
            "Large dataset detected, using 10000 random samples for the plots. Summary"
727
            " statistics are still based on the entire dataset.",
728
        )
729
    cols = list(data.select_dtypes(include=["number"]).columns)
730
    data = data[cols]
731
732
    if not cols:
733
        print("No columns with numeric data were detected.")
734
        return None
735
736
    if len(cols) >= 20 and not showall:  # noqa: PLR2004
737
        print(
738
            "Note: The number of non binary numerical features is very large "
739
            f"({len(cols)}), please consider splitting the data. Showing plots for "
740
            "the first 20 numerical features. Override this by setting showall=True.",
741
        )
742
        cols = cols[:20]
743
    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )

        # Vertical lines and fill
        x, y = g.axes[0, 0].lines[0].get_xydata().T
        g.axes[0, 0].fill_between(
            x,
            y,
            where=(
                (x >= np.quantile(col_df, fill_range[0]))
                & (x <= np.quantile(col_df, fill_range[1]))
            ),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        std = scipy.stats.tstd(col_df)
        g.axes[0, 0].vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        g.axes[0, 0].vlines(
            x=np.median(col_df),
            ymin=0,
            ymax=np.interp(np.median(col_df), x, y),
            ls=":",
            color=".3",
            label="median",
        )
        g.axes[0, 0].vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        g.axes[0, 0].set_ylim(0)
        # Pad the x-axis by a share of its span so the margins stay correct
        # when the axis limits are negative
        x_min, x_max = g.axes[0, 0].get_xlim()
        x_span = x_max - x_min
        g.axes[0, 0].set_xlim(x_min - x_span * 0.05, x_max + x_span * 0.03)

        # Annotations and legend
        g.axes[0, 0].text(
            0.005,
            0.9,
            f"Mean: {mean:.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.7,
            f"Std. dev: {std:.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.5,
            f"Skew: {scipy.stats.skew(col_df):.2f}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.3,
            f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}",  # Excess Kurtosis
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].text(
            0.005,
            0.1,
            f"Count: {len(col_df)}",
            fontdict=font_kws,
            transform=g.axes[0, 0].transAxes,
        )
        g.axes[0, 0].legend(loc="upper right")

    # Axes of the last distribution plot drawn (one plot per column in cols)
    return g.axes[0, 0]


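# Illustrative sketch (hypothetical helpers, not part of klib's API): the
# statistics that dist_plot() writes onto each plot, spelled out in plain Python
# so the formulas are visible. It assumes the defaults of the calls used above:
# np.mean, scipy.stats.tstd (sample standard deviation, ddof=1),
# scipy.stats.skew and scipy.stats.kurtosis with bias=True (kurtosis reported as
# excess kurtosis, fisher=True), and np.quantile's default "linear" method for
# the shaded fill_range band. The (0.025, 0.975) default for fill_range below is
# an assumption for illustration only.
import math
from collections.abc import Sequence


def _quantile_linear_sketch(sorted_vals: Sequence[float], q: float) -> float:
    """Linear-interpolation quantile of an already sorted sequence."""
    pos = (len(sorted_vals) - 1) * q
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def _dist_summary_sketch(
    values: Sequence[float],
    fill_range: tuple[float, float] = (0.025, 0.975),
) -> dict[str, float]:
    """Return the statistics shown in the dist_plot() annotations.

    Assumes at least three distinct values, which dist_plot() guarantees by
    keeping only columns with nunique() > 2.
    """
    n = len(values)
    mean = sum(values) / n
    deviations = [v - mean for v in values]
    # Central moments divided by n (bias=True), as in scipy.stats.skew/kurtosis
    m2 = sum(d**2 for d in deviations) / n
    m3 = sum(d**3 for d in deviations) / n
    m4 = sum(d**4 for d in deviations) / n
    ordered = sorted(values)
    return {
        "mean": mean,
        # Sample standard deviation (ddof=1), like scipy.stats.tstd without limits
        "std": math.sqrt(sum(d**2 for d in deviations) / (n - 1)),
        "skew": m3 / m2**1.5,
        # Excess kurtosis: subtract 3 so that a normal distribution scores 0
        "kurtosis": m4 / m2**2 - 3.0,
        "count": n,
        # Bounds of the shaded band between the two fill_range quantiles
        "fill_low": _quantile_linear_sketch(ordered, fill_range[0]),
        "fill_high": _quantile_linear_sketch(ordered, fill_range[1]),
    }


# Example: _dist_summary_sketch([1, 2, 3, 4, 5]) gives a mean of 3.0, a skew of
# 0.0, an excess kurtosis of -1.3 and a shaded band from 1.1 to 4.9 (up to
# floating-point rounding).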
def missingval_plot(  # noqa: PLR0915
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
) -> GridSpec | None:
    """Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used, e.g. "Greys", "RdPu". More information can be \
        found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns \
        without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid matplotlib \
        color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec or None
        gs: GridSpec of the figure containing the bar chart, heatmap, summary and \
        row-wise scatter Axes (the Figure itself is available as gs.figure). None is \
        returned if the dataset contains no missing values
    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Order columns by missing-value count (descending), keeping only
        # columns that contain at least one missing value
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted[mv_cols_sorted > 0].index.tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)  # noqa: FBT003
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols, strict=True):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            rotation=90,
            alpha=0.5,
            fontsize="11",
        )

    ax1.set_frame_on(True)  # noqa: FBT003
    for _, spine in ax1.spines.items():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    ax2.set_yticks(np.round(ax2.get_yticks()[::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="12",
    )
    ax2.tick_params(length=1, colors="#111111")
    for _, spine in ax2.spines.items():
        spine.set_visible(True)  # noqa: FBT003
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)  # noqa: FBT003
    ax3.get_yaxis().set_visible(False)  # noqa: FBT003
    ax3.set(frame_on=False)

    ax3.text(
        0.025,
        0.875,
        f"Total: {np.round(total_datapoints/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.675,
        f"Missing: {np.round(mv_total/1000,1)}K",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.475,
        f"Relative: {np.round(mv_total/total_datapoints*100,1)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.275,
        f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )
    ax3.text(
        0.025,
        0.075,
        f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%",
        transform=ax3.transAxes,
        fontdict=fontax3,
    )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)  # noqa: FBT003
    for _, spine in ax4.spines.items():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot",
        x=0.45,
        y=0.94,
        fontsize=18,
        color="#111111",
    )

    return gs

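# Illustrative sketch (hypothetical helper, not part of klib's API): the figures
# that missingval_plot() shows in its summary panel (ax3), before the K/percent
# rounding, plus the column order used when sort=True, computed in plain Python.
# As an assumption for this sketch, a dataset is a non-empty dict mapping column
# names to equally long lists, with None marking a missing value; in
# missingval_plot() the same counts come from data.isna().
from collections.abc import Mapping, Sequence
from typing import Any


def _missing_summary_sketch(columns: Mapping[str, Sequence[Any]]) -> dict[str, Any]:
    """Return the missing-value summary shown in the missingval_plot() panel."""
    n_cols = len(columns)
    n_rows = len(next(iter(columns.values())))
    # Missing values per column (like mv_cols) and per row (like mv_rows)
    mv_cols = {name: sum(v is None for v in vals) for name, vals in columns.items()}
    mv_rows = [sum(vals[i] is None for vals in columns.values()) for i in range(n_rows)]
    total = n_rows * n_cols
    missing = sum(mv_cols.values())
    # sort=True: most missing values first, complete columns dropped
    # (ties may be ordered differently than with pandas' sort_values)
    ranked = sorted(mv_cols.items(), key=lambda item: item[1], reverse=True)
    return {
        "total": total,
        "missing": missing,
        "relative_pct": missing / total * 100,
        "max_col_pct": max(mv_cols.values()) / n_rows * 100,
        "max_row_pct": max(mv_rows) / n_cols * 100,
        "sorted_cols": [name for name, count in ranked if count > 0],
    }


# Example: for {"a": [1, None, 3, None], "b": [None, 2, 3, 4]} this gives 8 data
# points, 3 missing (37.5%), a Max-col of 50.0%, a Max-row of 50.0% and the
# sorted column order ["a", "b"].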