GitHub Access Token became invalid

It seems like the GitHub access token used for retrieving details about this repository from GitHub became invalid. This might prevent certain types of inspections from being run (in particular, everything related to pull requests).
Please ask an admin of your repository to renew the access token on this website.

klib.describe   B
last analyzed

Complexity

Total Complexity 52

Size/Duplication

Total Lines 1044
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 534
dl 0
loc 1044
rs 7.44
c 0
b 0
f 0
wmc 52

6 Functions

Rating   Name   Duplication   Size   Complexity  
F corr_interactive_plot() 0 227 15
F dist_plot() 0 196 11
D cat_plot() 0 151 10
B corr_plot() 0 156 3
C missingval_plot() 0 180 7
B corr_mat() 0 80 6

How to fix   Complexity   

Complexity

Complex classes like klib.describe often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""Functions for descriptive analytics.
2
3
:author: Andreas Kanz
4
5
"""
6
7
from __future__ import annotations
8
9
from typing import Any
10
from typing import Literal
11
12
import matplotlib.pyplot as plt
13
import numpy as np
14
import pandas as pd
15
import plotly.graph_objects as go
16
import scipy
17
import seaborn as sns
18
from matplotlib import ticker
19
from matplotlib.colors import LinearSegmentedColormap
20
from matplotlib.colors import to_rgb
21
from matplotlib.gridspec import GridSpec  # noqa: TCH002
22
from screeninfo import get_monitors
23
from screeninfo import ScreenInfoError
24
25
from klib.utils import _corr_selector
26
from klib.utils import _missing_vals
27
from klib.utils import _validate_input_bool
28
from klib.utils import _validate_input_int
29
from klib.utils import _validate_input_num_data
30
from klib.utils import _validate_input_range
31
from klib.utils import _validate_input_smaller
32
from klib.utils import _validate_input_sum_larger
33
34
# Names exported as this module's public API (what `from <module> import *` exposes).
__all__ = [
35
    "cat_plot",
36
    "corr_interactive_plot",
37
    "corr_mat",
38
    "corr_plot",
39
    "dist_plot",
40
    "missingval_plot",
41
]
42
43
44
def cat_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
) -> GridSpec | None:
    """Two-dimensional visualization of number and frequency of categorical features.

    For every non-numeric column a bar chart of the most / least frequent values
    and a short text summary are drawn; below them a heatmap shows where in the
    dataset the "top" and "bottom" values occur.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" most frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values,
        by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common
        values, by default "#d8b365"

    Returns
    -------
    GridSpec | None
        gs: Figure with array of Axes objects, or None (after printing a
        message) when the data contains no non-numeric columns.

    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if not cols:
        print("No columns with categorical data were detected.")
        return None

    # The columns are overwritten with numeric codes below, which requires a
    # dtype that accepts arbitrary values.
    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Fewer distinct values than requested: split them between the two
        # groups so that no value is counted as both "top" and "bottom".
        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        value_counts_top = value_counts[:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        value_counts_bot = value_counts[-lim_bot:] if lim_bot > 0 else pd.DataFrame()
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Encode the column for the heatmap: 10 = top value, 0 = bottom value,
        # 5 = everything else; the rolling mean smooths neighbouring rows.
        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5  # noqa: PLR2004
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Tick labels are truncated to 20 characters. Values are converted with
        # str() first because non-numeric columns may hold non-string objects
        # (e.g. booleans or timestamps), which cannot be sliced.
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top,
            value_counts_top,
            color=bar_color_top,
            width=0.85,
        )
        ax_top.bar(
            value_counts_idx_bot,
            value_counts_bot,
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap",
        [color_bot_rgb, color_white, color_top_rgb],
        N=200,
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot",
        x=0.5,
        y=0.91,
        fontsize=18,
        color="#111111",
    )

    return gs
195
196
197
def corr_mat(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.DataFrame | pd.Series | np.ndarray | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    *,
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0
        unless split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | np.ndarray | list | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None.
        A string selects (and removes) that column of ``data``. A list or
        array of the same length as ``data`` is matched to the rows of
        ``data`` by position; a Series is matched by index.
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships.
            Computationally more expensive but more robust in smaller datasets
            than "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in
        red, by default True

    Returns
    -------
    pd.DataFrame | pd.Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        # An empty string means "no styling"; emitting "color: None" would be
        # invalid CSS.
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)
            # corrwith aligns on index labels. A plain list/array carries no
            # labels, so give it the labels of `data` (row-by-row match);
            # otherwise a non-default index on `data` would produce NaNs.
            if not isinstance(target, pd.Series) and len(target_data) == len(data):
                target_data.index = data.index
            target = target_data.name

        corr = pd.DataFrame(
            data.corrwith(target_data, method=method, numeric_only=True),
        )
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method, numeric_only=True)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap was renamed to Styler.map in pandas 2.1; prefer the
        # new name and fall back for older pandas versions.
        style_elementwise = getattr(styler, "map", None) or styler.applymap
        return style_elementwise(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
277
278
279
def corr_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    *,
    annot: bool = True,
    dev: bool = False,
    **kwargs,  # noqa: ANN003
) -> plt.Axes:
    """2D visualization of the correlation between feature-columns excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by
        default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the
                feature-columns above the threshold
            * neg: visualize all negative correlations between the
                feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for
                which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for
                which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0
        unless split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally
                distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic
                relationships.
            * kendall: ranked/ordinal correlation, measures monotonic
                relationships. Computationally more expensive but more robust
                in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name
        or object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10). Ignored when an
        existing Axes is supplied through ``ax`` in kwargs.
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False,
        the settings are not displayed, by default False

    kwargs : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the
                upper triangle is shown. Set dev = False in this case to avoid
                overlap.
            * vmax: float, default is calculated from the given correlation
                coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the
                cbar.
            * vmin: float, default is calculated from the given correlation
                coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the
                cbar.
            * center: float, default 0
                Value at which the colormap is centered.
            * fmt: str, default ".2f"
                Format string of the annotations.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * ax: matplotlib Axes, default is a newly created Axes
                Existing Axes to draw the heatmap into.
            * Many more kwargs are available, i.e. "alpha" to control blending,
                or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see
        above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    # Without a target the matrix is symmetric, so hide the redundant upper
    # triangle (including the diagonal).
    if target is None:
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Default color range derived from the visible coefficients.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    # Draw into a caller-supplied Axes if there is one, otherwise create a
    # new figure. The Axes is always passed to seaborn explicitly instead of
    # relying on the implicit "current axes".
    ax = kwargs.pop("ax", None)
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Defaults for the heatmap; anything in kwargs overrides them. "center"
    # and "fmt" live here (not as literal call arguments) so that passing them
    # via kwargs does not raise a duplicate-keyword TypeError.
    heatmap_kws = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "center": 0,
        "fmt": ".2f",
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, ax=ax, **heatmap_kws)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {heatmap_kws['linewidths']} \n\
            - annot_kws: {heatmap_kws['annot_kws']} \n\
            - cbar_kws: {heatmap_kws['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
435
436
437
def _screen_dpi(default: float = 96) -> float:
    """Return the horizontal DPI of the primary (else first) monitor, or ``default``."""
    try:
        monitors = get_monitors()
    except ScreenInfoError:
        return default

    # Primary monitors are preferred; the first reported monitor is the
    # fallback when no primary monitor exposes its physical size.
    candidates = [monitor for monitor in monitors if monitor.is_primary]
    candidates += monitors[:1]
    for monitor in candidates:
        # Physical size may be unreported (None) or reported as 0.
        if monitor.width_mm and monitor.height_mm:
            return monitor.width / (monitor.width_mm / 25.4)
    return default  # more or less arbitrary default value


def corr_interactive_plot(  # noqa: C901
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0.0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    *,
    annot: bool = True,
    **kwargs,  # noqa: ANN003
) -> go.Figure:
    """Interactive 2D visualization of the correlation between feature-columns.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a
        Pandas DataFrame is provided, the index/column information is
        used to label the plots.

    split : Optional[str], optional
        Type of split to be performed
        {None, "pos", "neg", "high", "low"}, by default None

        - None: visualize all correlations between the feature-columns

        - pos: visualize all positive correlations between the
            feature-columns above the threshold

        - neg: visualize all negative correlations between the
            feature-columns below the threshold

        - high: visualize all correlations between the
            feature-columns for which abs(corr) > threshold is True

        - low: visualize all correlations between the
            feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold,
        by default 0 unless split = "high" or split = "low", in
        which case the default is 0.3

    target : Optional[pd.Series | str], optional
        Specify a target for correlation. For example, the label column
        to generate only the correlations between each feature and the
        label, by default None

    method : Literal['pearson', 'spearman', 'kendall'], optional
        Method for correlation calculation:
        {"pearson", "spearman", "kendall"}, by default "pearson"

        - pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        - spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        - kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more
            robust in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, plotly
        colormap name or object, or list of colors, by default "BrBG"

    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)

    annot : bool, optional
        Use to show or hide annotations, by default True

    **kwargs : optional
        Additional elements to control the visualization of the plot.
        These are passed to `go.Heatmap` and override the defaults set
        by this function:

        - colorscale: defaults to ``cmap``
        - zmax / zmin: default to a range derived from the coefficients
        - text: defaults to the rounded coefficients (if ``annot``)
        - texttemplate: defaults to "%{text}"
        - textfont: defaults to {"size": 12}
        - x / y: default to the column / row labels of the matrix
        - z: defaults to the (masked) correlation matrix

        Many more kwargs are available, e.g., "hovertemplate"; please
        refer to the Plotly Heatmap documentation.

    Returns
    -------
    heatmap : plotly.graph_objs._figure.Figure
        A Plotly Figure object representing the heatmap visualization of
        feature correlations.

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")

    # Columns are reversed here and the x-axis is reversed in the layout.
    data = pd.DataFrame(data).iloc[:, ::-1]

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    if target is None:
        # Keep only the strict upper triangle (k=1 drops the diagonal too).
        # Built as a fresh mask rather than by writing into corr.to_numpy(),
        # which is read-only when pandas Copy-on-Write is active.
        strict_upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
        corr = corr.where(strict_upper)
    else:
        corr = corr.iloc[::-1, :]

    vmax = np.round(np.nanmax(corr) - 0.05, 2)
    vmin = np.round(np.nanmin(corr) + 0.05, 2)

    # One-sided splits get a color range mirrored around zero.
    vmax = -vmin if split == "neg" else vmax
    vmin = -vmax if split == "pos" else vmin

    vtext = corr.round(2).fillna("") if annot else None

    corr_columns = corr.columns
    corr_index = corr.index

    # Flatten MultiIndex labels into single strings; levels are converted
    # with str() so that non-string levels do not break the join.
    if isinstance(corr_columns, pd.MultiIndex):
        corr_columns = ["-".join(map(str, col)) for col in corr.columns]

    if isinstance(corr_index, pd.MultiIndex):
        corr_index = ["-".join(map(str, idx)) for idx in corr.index]

    # Specify kwargs for the heatmap (user-supplied kwargs take precedence)
    kwargs = {
        "colorscale": cmap,
        "zmax": vmax,
        "zmin": vmin,
        "text": vtext,
        "texttemplate": "%{text}",
        "textfont": {"size": 12},
        "x": corr_columns,
        "y": corr_index,
        "z": corr,
        **kwargs,
    }

    # Draw heatmap with masked corr and default settings
    heatmap = go.Figure(
        data=go.Heatmap(
            hoverongaps=False,
            xgap=1,
            ygap=1,
            **kwargs,
        ),
    )

    # figsize is given in inches; plotly expects pixels.
    dpi = _screen_dpi()

    heatmap.update_layout(
        title=f"Feature-correlation ({method})",
        title_font={"size": 24},
        title_x=0.5,
        autosize=True,
        width=figsize[0] * dpi,
        height=(figsize[1] + 1) * dpi,
        xaxis={"autorange": "reversed"},
    )

    return heatmap
664
665
666
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: int = 3,
    fill_range: tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: dict[str, Any] | None = None,
    rug_kws: dict[str, Any] | None = None,
    fill_kws: dict[str, Any] | None = None,
    font_kws: dict[str, Any] | None = None,
) -> None | Any:  # noqa: ANN401
    """2D visualization of the distribution of non binary numerical features.

    One KDE plot (with rug, shaded quantile range, mean / median / std lines
    and summary statistics) is drawn per numeric column that has more than two
    distinct values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label
        the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default
        "orange"
    size : float, optional
        Controls the plot size, by default 3
    fill_range : tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is
        about two std. deviations above and below the mean, by default
        (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {"alpha": 0.75,
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : dict[str, Any], optional
        Keyword arguments for rugplot(), by default {"color": "#ff3333",
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff",
        "alpha": 0.2}
    font_kws : dict[str, Any], optional
        Keyword arguments to control the font, by default {"color":  "#111111",
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking, or None
        (after printing a message) if no suitable numeric column exists.

    """
    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies keep the caller's dicts untouched)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8} if kde_kws is None else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    # Coerce first, then drop all-NaN columns (dropna returns a new frame), so
    # inputs without a .copy() method are accepted as well.
    data = pd.DataFrame(data).dropna(axis=1, how="all")
    df = data.copy()  # full dataset, used for the summary statistics
    data = data.loc[:, data.nunique() > 2]  # noqa: PLR2004
    if data.shape[0] > 10000:  # noqa: PLR2004
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset.",
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    # Only announce the limit when plots are actually being left out.
    if len(cols) > 20 and not showall:  # noqa: PLR2004
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True.",
        )
        cols = cols[:20]

    ax = None  # Axes of the most recently drawn plot
    for col in cols:
        col_data = data[col].dropna(axis=0)  # (possibly sampled) data to plot
        col_df = df[col].dropna(axis=0)  # complete data for the statistics

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill
        x, y = ax.lines[0].get_xydata().T
        q_low = np.quantile(col_df, fill_range[0])
        q_high = np.quantile(col_df, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= q_low) & (x <= q_high)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = scipy.stats.tstd(col_df)
        # Each line ends where it meets the KDE curve.
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03bc \u00b1 \u03c3",
        )

        ax.set_ylim(0)
        x_low, x_high = ax.get_xlim()
        ax.set_xlim(x_low - x_high * 0.05, x_high * 1.03)

        # Annotations and legend
        annotations = (
            (0.9, f"Mean: {mean:.2f}"),
            (0.7, f"Std. dev: {std:.2f}"),
            (0.5, f"Skew: {scipy.stats.skew(col_df):.2f}"),
            (0.3, f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}"),  # Excess Kurtosis
            (0.1, f"Count: {len(col_df)}"),
        )
        for y_pos, label in annotations:
            ax.text(
                0.005,
                y_pos,
                label,
                fontdict=font_kws,
                transform=ax.transAxes,
            )
        ax.legend(loc="upper right")

    return ax
0 ignored issues
show
introduced by
The variable g does not seem to be defined in case the for loop on line 761 is not entered. Are you sure this can never be the case?
Loading history...
862
863
864
def missingval_plot(  # noqa: PLR0915
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
) -> GridSpec | None:
    """Two-dimensional visualization of the missing values in a dataset.

    The figure is split into four panels: a bar chart with the share of missing
    values per column (top left), a heatmap marking every missing cell (bottom
    left), a textual summary (top right) and a scatter plot of the missing values
    per row (bottom right).

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information can be \
        found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns \
        without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid matplotlib \
        color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec | None
        gs: Figure with array of Axes objects. ``None`` is returned (and a message \
        is printed) when the dataset contains no missing values, in which case no \
        figure is created.

    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Order columns by their number of missing values (descending) and keep
        # only those that actually contain missing values.
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted[mv_cols_sorted > 0].index.tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    # NOTE(review): relies on the value order of the dict returned by
    # `_missing_vals`; presumably total count, per-row counts, per-column counts,
    # an unused entry and per-column ratios -- confirm against klib.utils.
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes: 6x6 grid, the left five columns hold the bar chart
    # and the heatmap, the right column holds the summary and the row scatter.
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols, strict=True):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            # offset grows with the bar height but never drops below 0.075
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            rotation=90,
            alpha=0.5,
            fontsize="11",
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # keep every fifth tick, rounded to the nearest ten, and label it by position
    ax2.set_yticks(np.round(ax2.get_yticks()[::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="12",
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    # One summary line per entry, stacked top to bottom in axes coordinates.
    summary_lines = (
        f"Total: {np.round(total_datapoints / 1000, 1)}K",
        f"Missing: {np.round(mv_total / 1000, 1)}K",
        f"Relative: {np.round(mv_total / total_datapoints * 100, 1)}%",
        f"Max-col: {np.round(mv_cols.max() / data.shape[0] * 100)}%",
        f"Max-row: {np.round(mv_rows.max() / data.shape[1] * 100)}%",
    )
    summary_y_positions = (0.875, 0.675, 0.475, 0.275, 0.075)
    for y_pos, line in zip(summary_y_positions, summary_lines, strict=True):
        ax3.text(
            0.025,
            y_pos,
            line,
            transform=ax3.transAxes,
            fontdict=fontax3,
        )

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    # marker size and color both scale with the per-row value
    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot",
        x=0.45,
        y=0.94,
        fontsize=18,
        color="#111111",
    )

    return gs
1044