1
|
|
|
"""Functions for descriptive analytics. |
2
|
|
|
|
3
|
|
|
:author: Andreas Kanz |
4
|
|
|
|
5
|
|
|
""" |
6
|
|
|
|
7
|
|
|
from __future__ import annotations |
8
|
|
|
|
9
|
|
|
from typing import Any |
10
|
|
|
from typing import Literal |
11
|
|
|
|
12
|
|
|
import matplotlib.pyplot as plt |
13
|
|
|
import numpy as np |
14
|
|
|
import pandas as pd |
15
|
|
|
import plotly.graph_objects as go |
16
|
|
|
import scipy |
17
|
|
|
import seaborn as sns |
18
|
|
|
from matplotlib import ticker |
19
|
|
|
from matplotlib.colors import LinearSegmentedColormap |
20
|
|
|
from matplotlib.colors import to_rgb |
21
|
|
|
from matplotlib.gridspec import GridSpec # noqa: TCH002 |
22
|
|
|
from screeninfo import get_monitors |
23
|
|
|
from screeninfo import ScreenInfoError |
24
|
|
|
|
25
|
|
|
from klib.utils import _corr_selector |
26
|
|
|
from klib.utils import _missing_vals |
27
|
|
|
from klib.utils import _validate_input_bool |
28
|
|
|
from klib.utils import _validate_input_int |
29
|
|
|
from klib.utils import _validate_input_num_data |
30
|
|
|
from klib.utils import _validate_input_range |
31
|
|
|
from klib.utils import _validate_input_smaller |
32
|
|
|
from klib.utils import _validate_input_sum_larger |
33
|
|
|
|
34
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["cat_plot", "corr_interactive_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
42
|
|
|
|
43
|
|
|
|
44
|
|
|
def cat_plot(  # noqa: C901, PLR0915
    data: pd.DataFrame,
    figsize: tuple[float, float] = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
) -> GridSpec:
    """Two-dimensional visualization of number and frequency of categorical features.

    For every non-numeric column the figure shows a bar chart of its most and least
    frequent values and a short text summary. Below them, a shared heatmap colors
    each cell by whether its value is among the "top" values, the "bottom" values,
    or neither.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame
        is provided, the index/column information is used to label the plots
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" least frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values, by
        default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values, by
        default "#d8b365"

    Returns
    -------
    GridSpec
        gs: GridSpec holding all subplots (the figure is available as
        ``gs.figure``). None is returned, after printing a notice, if ``data``
        contains no non-numeric columns.

    """
    # Validate Inputs
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    data = pd.DataFrame(data).copy()
    cols = data.select_dtypes(exclude=["number"]).columns.tolist()
    data = data[cols]

    if len(cols) == 0:
        print("No columns with categorical data were detected.")
        return None

    for col in data.columns:
        if data[col].dtype.name in ("category", "string"):
            data[col] = data[col].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)
    # Vertical spacing is a figure-wide setting: set it once on this figure instead
    # of once per column through pyplot's "current figure".
    fig.subplots_adjust(hspace=0.075)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # With fewer distinct values than requested, split them between the two
        # groups so that no value ends up in both.
        if n_unique < top + bottom:
            if bottom > top:
                lim_top = min(int(n_unique // 2), top)
                lim_bot = n_unique - lim_top
            else:
                lim_bot = min(int(n_unique // 2), bottom)
                lim_top = n_unique - lim_bot

        # value_counts is ordered by descending frequency; slice positionally.
        value_counts_top = value_counts.iloc[:lim_top]
        value_counts_bot = (
            value_counts.iloc[-lim_bot:] if lim_bot > 0 else value_counts.iloc[:0]
        )

        # Encode every cell as 10 (top group), 0 (bottom group) or 5 (anything
        # else, including missing values). Both masks are computed on the original
        # values before anything is overwritten, so category values that happen
        # to equal 0 or 10 are not confused with already-encoded cells.
        is_top = data[col].isin(value_counts_top.index).to_numpy()
        is_bot = data[col].isin(value_counts_bot.index).to_numpy()
        encoded = pd.Series(
            np.select([is_top, is_bot], [10, 0], default=5),  # noqa: PLR2004
            index=data.index,
            dtype=float,
        )
        # Average each row with its predecessor before drawing the heatmap.
        data[col] = encoded.rolling(2, min_periods=1).mean()

        # Truncate long labels. str() keeps non-string objects (ints, bools, ...)
        # stored in object columns from breaking the slicing.
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_top.index]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_bot.index]
        sum_top = int(value_counts_top.sum())
        sum_bot = int(value_counts_bot.sum())

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top,
            value_counts_top.to_numpy(),
            color=bar_color_top,
            width=0.85,
        )
        ax_top.bar(
            value_counts_idx_bot,
            value_counts_bot.to_numpy(),
            color=bar_color_bottom,
            width=0.85,
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top}: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot}: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap: bottom color -> white -> top color, matching the 0 / 5 / 10 encoding.
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap",
        [color_bot_rgb, color_white, color_top_rgb],
        N=200,
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    # Keep every fifth y tick, rounded to the nearest ten.
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot",
        x=0.5,
        y=0.91,
        fontsize=18,
        color="#111111",
    )

    return gs
195
|
|
|
|
196
|
|
|
|
197
|
|
|
def corr_mat(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.DataFrame | pd.Series | np.ndarray | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    *,
    colored: bool = True,
) -> pd.DataFrame | pd.Series:
    """Return a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame
        is provided, the index/column information is used to label the plots
    split : Optional[Literal['pos', 'neg', 'high', 'low']], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Value between -1 and 1 to set the correlation threshold, by default 0 unless
        split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.DataFrame | pd.Series | np.ndarray | list | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None.
        A column name (str) is taken from ``data`` and removed from the features.
        A single-column DataFrame is treated like the equivalent Series. Any other
        DataFrame is ignored and the full correlation matrix is returned.
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally distributed
        and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships.
        Computationally more expensive but more robust in smaller datasets than
        "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in red, by
        default True

    Returns
    -------
    pd.DataFrame | pd.Styler
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val: float) -> str:
        # An empty string means "no style" for non-negative cells. The previous
        # "color: None" was not a valid CSS value.
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    _validate_input_num_data(data, "data")

    # The signature allows a DataFrame target; a single column is unambiguous.
    if isinstance(target, pd.DataFrame) and target.shape[1] == 1:
        target = target.iloc[:, 0]

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            target_data = pd.Series(target)
            target = target_data.name

        corr = pd.DataFrame(
            data.corrwith(target_data, method=method, numeric_only=True),
        )
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method, numeric_only=True)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map;
        # fall back to applymap on older pandas versions.
        elementwise = getattr(styler, "map", None) or styler.applymap
        return elementwise(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
277
|
|
|
|
278
|
|
|
|
279
|
|
|
def corr_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    *,
    annot: bool = True,
    dev: bool = False,
    **kwargs,  # noqa: ANN003
) -> plt.Axes:
    """2D visualization of the correlation between feature-columns excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame
        is provided, the index/column information is used to label the plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by default
        None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the feature-columns
              above the threshold
            * neg: visualize all negative correlations between the feature-columns
              below the threshold
            * high: visualize all correlations between the feature-columns for
              which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for which
              abs(corr) < threshold is True

    threshold : float, optional
        Value between -1 and 1 to set the correlation threshold, by default 0 unless
        split = "high" or split = "low", in which case default is 0.3
    target : Optional[pd.Series | str], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None
    method : Literal['pearson', 'spearman', 'kendall'], optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally
              distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic relationships.
            * kendall: ranked/ordinal correlation, measures monotonic relationships.
              Computationally more expensive but more robust in smaller datasets
              than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or
        object, or list of colors, by default "BrBG"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the
        settings are not displayed, by default False

    kwargs : optional
        Additional keyword arguments for seaborn.heatmap. They override the
        defaults set here, e.g.:

            * mask: by default the upper triangle (including the diagonal) is
              hidden when no target is given. Pass mask=False to show the entire
              correlation matrix. Set dev = False in this case to avoid overlap.
            * vmax: float, default is calculated from the given correlation
              coefficients. Value between -1 or vmin <= vmax <= 1, limits the
              range of the cbar.
            * vmin: float, default is calculated from the given correlation
              coefficients. Value between -1 <= vmin <= 1 or vmax, limits the
              range of the cbar.
            * linewidths: float, default 0.5
              Controls the line-width in between the squares.
            * annot_kws: dict, default {"size" : 10}
              Controls the font size of the annotations. Only available when
              annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
              Controls the size of the colorbar.
            * center (default 0) and fmt (default ".2f") can be overridden too.
            * Many more kwargs are available, i.e. "alpha" to control blending, or
              options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        # Hide the redundant upper triangle (and the diagonal of ones).
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Color range from the visible values, pulled slightly inwards.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Defaults for the heatmap; anything passed in **kwargs takes precedence. The
    # target axes are passed explicitly instead of relying on pyplot's current axes.
    heatmap_kws = {
        "ax": ax,
        "center": 0,
        "fmt": ".2f",
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, **heatmap_kws)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        # Built line by line so the displayed text does not depend on how this
        # source file happens to be indented.
        settings_text = "\n".join(
            [
                "Settings (dev-mode):",
                f"- split-mode: {split}",
                f"- threshold: {threshold}",
                f"- method: {method}",
                f"- annotations: {annot}",
                "- cbar:",
                f"    - vmax: {vmax}",
                f"    - vmin: {vmin}",
                f"- linewidths: {heatmap_kws['linewidths']}",
                f"- annot_kws: {heatmap_kws['annot_kws']}",
                f"- cbar_kws: {heatmap_kws['cbar_kws']}",
            ],
        )
        fig.suptitle(
            settings_text,
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
435
|
|
|
|
436
|
|
|
|
437
|
|
|
def _estimate_screen_dpi(default: float = 96) -> float:
    """Estimate the pixel density (pixels per inch) of the user's screen.

    A monitor flagged as primary is preferred; otherwise the first monitor
    reported by screeninfo is used. ``default`` is returned when screeninfo raises
    ``ScreenInfoError``, reports no monitors, or the chosen monitor has no usable
    (missing or zero) physical size.
    """
    try:
        monitors = list(get_monitors())
    except ScreenInfoError:
        return default
    if not monitors:
        return default

    candidates = [monitor for monitor in monitors if monitor.is_primary]
    candidates.append(monitors[0])
    for monitor in candidates:
        # width_mm / height_mm can be None (unknown) or 0 on virtual displays; the
        # latter would previously raise ZeroDivisionError.
        if monitor.width_mm and monitor.height_mm:
            return monitor.width / (monitor.width_mm / 25.4)
    return default


def corr_interactive_plot(
    data: pd.DataFrame,
    split: Literal["pos", "neg", "high", "low"] | None = None,
    threshold: float = 0.0,
    target: pd.Series | str | None = None,
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    cmap: str = "BrBG",
    figsize: tuple[float, float] = (12, 10),
    *,
    annot: bool = True,
    **kwargs,  # noqa: ANN003
) -> go.Figure:
    """Interactive 2D visualization of the correlation between feature-columns.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into a Pandas DataFrame. If a
        Pandas DataFrame is provided, the index/column information is
        used to label the plots.

    split : Optional[str], optional
        Type of split to be performed
        {None, "pos", "neg", "high", "low"}, by default None

        - None: visualize all correlations between the feature-columns
        - pos: visualize all positive correlations between the
            feature-columns above the threshold
        - neg: visualize all negative correlations between the
            feature-columns below the threshold
        - high: visualize all correlations between the
            feature-columns for which abs(corr) > threshold is True
        - low: visualize all correlations between the
            feature-columns for which abs(corr) < threshold is True

    threshold : float, optional
        Value between -1 and 1 to set the correlation threshold,
        by default 0 unless split = "high" or split = "low", in
        which case the default is 0.3

    target : Optional[pd.Series | str], optional
        Specify a target for correlation. For example, the label column
        to generate only the correlations between each feature and the
        label, by default None

    method : Literal['pearson', 'spearman', 'kendall'], optional
        Method for correlation calculation:
        {"pearson", "spearman", "kendall"}, by default "pearson"

        - pearson: measures linear relationships and requires normally
            distributed and homoscedastic data.
        - spearman: ranked/ordinal correlation, measures monotonic
            relationships.
        - kendall: ranked/ordinal correlation, measures monotonic
            relationships. Computationally more expensive but more
            robust in smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, plotly
        colormap name or object, or list of colors, by default "BrBG"

    figsize : tuple[float, float], optional
        Use to control the figure size (in inches, converted to pixels
        with the estimated screen DPI), by default (12, 10)

    annot : bool, optional
        Use to show or hide annotations, by default True

    **kwargs : optional
        Additional arguments passed to `go.Heatmap`. They override the
        defaults set here for: colorscale, zmax, zmin, text, texttemplate,
        textfont, x, y and z. Many more kwargs are available, e.g.
        "hovertemplate"; see the Plotly Heatmap documentation.

    Returns
    -------
    heatmap : plotly.graph_objs._figure.Figure
        A Plotly Figure object representing the heatmap visualization of
        feature correlations.

    """
    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")

    # Columns are reversed here; the layout below reverses the x-axis again.
    data = pd.DataFrame(data).iloc[:, ::-1]

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    if target is None:
        # Keep only the strict upper triangle (the diagonal is hidden as well).
        # This used to be done by writing NaN into corr.to_numpy() in place, which
        # is silently lost when to_numpy() returns a copy and raises when it
        # returns a read-only view (pandas Copy-on-Write).
        strict_upper = np.triu(np.ones(corr.shape, dtype=bool), k=1)
        corr = corr.where(strict_upper)
    else:
        corr = corr.iloc[::-1, :]

    vmax = np.round(np.nanmax(corr) - 0.05, 2)
    vmin = np.round(np.nanmin(corr) + 0.05, 2)

    # One-sided splits get a color range that is symmetric around zero.
    if split == "neg":
        vmax = -vmin
    elif split == "pos":
        vmin = -vmax

    vtext = corr.round(2).fillna("") if annot else None

    corr_columns = corr.columns
    corr_index = corr.index

    # Plotly needs flat labels: join MultiIndex levels. str() makes this work for
    # non-string levels (e.g. integers) as well.
    if isinstance(corr_columns, pd.MultiIndex):
        corr_columns = ["-".join(map(str, col)) for col in corr_columns]

    if isinstance(corr_index, pd.MultiIndex):
        corr_index = ["-".join(map(str, idx)) for idx in corr_index]

    # Defaults for the heatmap; anything passed in **kwargs takes precedence.
    heatmap_kws = {
        "colorscale": cmap,
        "zmax": vmax,
        "zmin": vmin,
        "text": vtext,
        "texttemplate": "%{text}",
        "textfont": {"size": 12},
        "x": corr_columns,
        "y": corr_index,
        "z": corr,
        **kwargs,
    }

    # Draw heatmap with masked corr and default settings
    heatmap = go.Figure(
        data=go.Heatmap(
            hoverongaps=False,
            xgap=1,
            ygap=1,
            **heatmap_kws,
        ),
    )

    dpi = _estimate_screen_dpi()

    heatmap.update_layout(
        title=f"Feature-correlation ({method})",
        title_font={"size": 24},
        title_x=0.5,
        autosize=True,
        width=figsize[0] * dpi,
        height=(figsize[1] + 1) * dpi,
        xaxis={"autorange": "reversed"},
    )

    return heatmap
664
|
|
|
|
665
|
|
|
|
666
|
|
|
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: int = 3,
    fill_range: tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: dict[str, Any] | None = None,
    rug_kws: dict[str, Any] | None = None,
    fill_kws: dict[str, Any] | None = None,
    font_kws: dict[str, Any] | None = None,
) -> None | Any:  # noqa: ANN401
    """2D visualization of the distribution of non binary numerical features.

    One kde + rug plot is drawn per numeric column with more than two distinct
    values. Each plot is shaded between the ``fill_range`` quantiles, marks the
    mean, median and mean +/- one standard deviation, and lists summary
    statistics. With more than 10000 rows, a random sample of 10000 rows is
    plotted, while the summary statistics use all rows.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame
        is provided, the index/column information is used to label the plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default "orange"
    size : float, optional
        Controls the plot size, by default 3
    fill_range : tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is about
        two std. deviations above and below the mean, by default (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {"alpha": 0.75,
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : dict[str, Any], optional
        Keyword arguments for rugplot(), by default {"color": "#ff3333",
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff",
        "alpha": 0.2}
    font_kws : dict[str, Any], optional
        Keyword arguments to control the font, by default {"color": "#111111",
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        The Axes of the last feature plotted, for further tweaking. None (after
        printing a notice) if no suitable numeric columns are found.

    """
    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (copies, so the caller's dicts are never mutated)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8} if kde_kws is None else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    data = pd.DataFrame(data.copy()).dropna(axis=1, how="all")
    df = data.copy()  # all rows, used for the summary statistics
    data = data.loc[:, data.nunique() > 2]  # noqa: PLR2004
    if data.shape[0] > 10000:  # noqa: PLR2004
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset.",
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    if len(cols) >= 20 and not showall:  # noqa: PLR2004
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True.",
        )
        cols = cols[:20]

    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill; the kde curve supplies the heights.
        x, y = ax.lines[0].get_xydata().T
        lower_q, upper_q = np.quantile(col_df, fill_range)
        ax.fill_between(
            x,
            y,
            where=((x >= lower_q) & (x <= upper_q)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = scipy.stats.tstd(col_df)
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03bc \u00b1 \u03c3",
        )

        ax.set_ylim(0)
        # Pad the x-range by 5% (left) and 3% (right) of |upper limit|. abs() keeps
        # the padding outward when the upper limit is negative; the previous
        # `hi * 1.03` pulled the right edge inwards and cropped the plot there.
        x_lo, x_hi = ax.get_xlim()
        ax.set_xlim(x_lo - abs(x_hi) * 0.05, x_hi + abs(x_hi) * 0.03)

        # Annotations and legend
        annotations = [
            (0.9, f"Mean: {mean:.2f}"),
            (0.7, f"Std. dev: {std:.2f}"),
            (0.5, f"Skew: {scipy.stats.skew(col_df):.2f}"),
            (0.3, f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}"),  # Excess Kurtosis
            (0.1, f"Count: {len(col_df)}"),
        ]
        for y_pos, text in annotations:
            ax.text(0.005, y_pos, text, fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc="upper right")

    return ax
|
|
|
|
862
|
|
|
|
863
|
|
|
|
864
|
|
|
def missingval_plot(  # noqa: PLR0915
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: tuple[float, float] = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
) -> GridSpec | None:
    """Two-dimensional visualization of the missing values in a dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas DataFrame \
        is provided, the index/column information is used to label the plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information can be \
        found in the matplotlib documentation, by default "PuBuGn"
    figsize : tuple[float, float], optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop columns \
        without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid matplotlib \
        color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec | None
        gs: Figure with array of Axes objects. None if the dataset contains no \
        missing values, in which case a message is printed and nothing is plotted.

    """
    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Order columns by their number of missing values (descending) and keep only
        # the columns that contain at least one missing value.
        mv_per_col = data.isna().sum(axis=0).sort_values(ascending=False)
        data = data[mv_per_col[mv_per_col > 0].index.tolist()]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes: bar plot (top-left), heatmap (bottom-left),
    # summary text (top-right) and per-row scatter plot (bottom-right)
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=1))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols, strict=True):
        height = rect.get_height()
        ax1.text(
            rect.get_x() + rect.get_width() / 2,
            height + max(np.log(1 + height / 6), 0.075),
            label,
            ha="center",
            va="bottom",
            rotation=90,
            alpha=0.5,
            fontsize="11",
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # Keep every 5th default tick, rounded to the nearest 10. Rounding can create
    # duplicates and can push the last tick beyond the final row; set_yticks would
    # then expand the view limits and leave an empty band below the heatmap. So
    # deduplicate and drop ticks outside the data range.
    yticks = np.unique(np.round(ax2.get_yticks()[::5], -1))
    yticks = yticks[yticks <= data.shape[0]]
    ax2.set_yticks(yticks)
    ax2.set_yticklabels([int(tick) for tick in yticks])
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="12",
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    # (vertical position in axes coordinates, text) for each summary line
    summary_lines = [
        (0.875, f"Total: {np.round(total_datapoints/1000,1)}K"),
        (0.675, f"Missing: {np.round(mv_total/1000,1)}K"),
        (0.475, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%"),
        (0.275, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%"),
        (0.075, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%"),
    ]
    for y_pos, text in summary_lines:
        ax3.text(0.025, y_pos, text, transform=ax3.transAxes, fontdict=fontax3)

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot",
        x=0.45,
        y=0.94,
        fontsize=18,
        color="#111111",
    )

    return gs
1044
|
|
|
|