1
|
|
|
""" |
2
|
|
|
Functions for descriptive analytics. |
3
|
|
|
|
4
|
|
|
:author: Andreas Kanz |
5
|
|
|
|
6
|
|
|
""" |
7
|
|
|
|
8
|
|
|
# Imports |
9
|
|
|
import matplotlib.pyplot as plt |
10
|
|
|
import matplotlib.ticker as ticker |
11
|
|
|
import numpy as np |
12
|
|
|
import pandas as pd |
13
|
|
|
import scipy |
14
|
|
|
import seaborn as sns |
15
|
|
|
from matplotlib.colors import LinearSegmentedColormap, to_rgb |
16
|
|
|
from typing import Any, Dict, Optional, Tuple, Union |
17
|
|
|
|
18
|
|
|
from klib.utils import ( |
19
|
|
|
_corr_selector, |
20
|
|
|
_missing_vals, |
21
|
|
|
_validate_input_bool, |
22
|
|
|
_validate_input_int, |
23
|
|
|
_validate_input_range, |
24
|
|
|
_validate_input_smaller, |
25
|
|
|
_validate_input_sum_larger, |
26
|
|
|
) |
27
|
|
|
|
28
|
|
|
# Names exported by `from <module> import *`; all five are defined below.
__all__ = ["cat_plot", "corr_mat", "corr_plot", "dist_plot", "missingval_plot"]
29
|
|
|
|
30
|
|
|
|
31
|
|
|
# Functions |
32
|
|
|
|
33
|
|
|
# Categorical Plot |
34
|
|
|
def cat_plot(
    data: pd.DataFrame,
    figsize: Tuple = (18, 18),
    top: int = 3,
    bottom: int = 3,
    bar_color_top: str = "#5ab4ac",
    bar_color_bottom: str = "#d8b365",
):
    """Two-dimensional visualization of the number and frequency of categorical
    features.

    For every non-numeric column a bar chart of its most and least frequent
    values and a short text summary are drawn. Below them a single heatmap
    shows, row by row, where these values occur in the dataset.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label the
        plots
    figsize : Tuple, optional
        Use to control the figure size, by default (18, 18)
    top : int, optional
        Show the "top" most frequent values in a column, by default 3
    bottom : int, optional
        Show the "bottom" least frequent values in a column, by default 3
    bar_color_top : str, optional
        Use to control the color of the bars indicating the most common values,
        by default "#5ab4ac"
    bar_color_bottom : str, optional
        Use to control the color of the bars indicating the least common values,
        by default "#d8b365"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects. None is returned (after printing
        a note) if no non-numeric columns are found.
    """

    # Coerce first so that the shape-based validation below also works for
    # inputs that are not yet a DataFrame.
    data = pd.DataFrame(data).copy()

    # Validate Inputs
    # NOTE(review): the upper bound for top/bottom is the number of columns of
    # the full dataset - confirm this is the intended limit.
    _validate_input_int(top, "top")
    _validate_input_int(bottom, "bottom")
    _validate_input_range(top, "top", 0, data.shape[1])
    _validate_input_range(bottom, "bottom", 0, data.shape[1])
    _validate_input_sum_larger(1, "top and bottom", top, bottom)

    cols = data.select_dtypes(exclude=["number"]).columns.tolist()

    if not cols:
        print("No columns with categorical data were detected.")
        return None

    # The integer codes 10 / 5 / 0 are written into the columns below, so every
    # selected column (category, string, bool, datetime, ...) is held as object.
    data = data[cols].astype("object")

    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=len(cols), wspace=0.21)

    for count, col in enumerate(cols):
        n_unique = data[col].nunique(dropna=True)
        value_counts = data[col].value_counts()
        lim_top, lim_bot = top, bottom

        # Fewer distinct values than requested bars: split them between the
        # two groups instead.
        if n_unique < top + bottom:
            lim_top = int(n_unique // 2)
            lim_bot = int(n_unique // 2) + 1

        if n_unique <= 2:
            lim_top = lim_bot = int(n_unique // 2)

        value_counts_top = value_counts.iloc[:lim_top]
        value_counts_idx_top = value_counts_top.index.tolist()
        # NOTE: for lim_bot == 0 the slice "-0:" selects the whole series.
        value_counts_bot = value_counts.iloc[-lim_bot:]
        value_counts_idx_bot = value_counts_bot.index.tolist()

        if top == 0:
            value_counts_top = value_counts_idx_top = []

        if bottom == 0:
            value_counts_bot = value_counts_idx_bot = []

        # Encode the column for the heatmap: 10 = top value, 0 = bottom value,
        # 5 = anything else (including missing values). The rolling mean over
        # two rows softens the transitions between neighbouring rows.
        data.loc[data[col].isin(value_counts_idx_top), col] = 10
        data.loc[data[col].isin(value_counts_idx_bot), col] = 0
        data.loc[((data[col] != 10) & (data[col] != 0)), col] = 5
        data[col] = data[col].rolling(2, min_periods=1).mean()

        # Tick labels are truncated to 20 characters; str() keeps this working
        # for non-string labels such as booleans or timestamps.
        value_counts_idx_top = [str(elem)[:20] for elem in value_counts_idx_top]
        value_counts_idx_bot = [str(elem)[:20] for elem in value_counts_idx_bot]
        sum_top = sum(value_counts_top)
        sum_bot = sum(value_counts_bot)

        # Barcharts
        ax_top = fig.add_subplot(gs[:1, count : count + 1])
        ax_top.bar(
            value_counts_idx_top, value_counts_top, color=bar_color_top, width=0.85
        )
        ax_top.bar(
            value_counts_idx_bot, value_counts_bot, color=bar_color_bottom, width=0.85
        )
        ax_top.set(frame_on=False)
        ax_top.tick_params(axis="x", labelrotation=90)

        # Summary stats
        ax_bottom = fig.add_subplot(gs[1:2, count : count + 1])
        plt.subplots_adjust(hspace=0.075)
        ax_bottom.get_yaxis().set_visible(False)
        ax_bottom.get_xaxis().set_visible(False)
        ax_bottom.set(frame_on=False)
        ax_bottom.text(
            0,
            0,
            f"Unique values: {n_unique}\n\n"
            f"Top {lim_top} vals: {sum_top} ({sum_top/data.shape[0]*100:.1f}%)\n"
            f"Bot {lim_bot} vals: {sum_bot} ({sum_bot/data.shape[0]*100:.1f}%)",
            transform=ax_bottom.transAxes,
            color="#111111",
            fontsize=11,
        )

    # Heatmap
    # Colormap running from the "bottom" color over white to the "top" color,
    # matching the 0 / 5 / 10 encoding above.
    color_bot_rgb = to_rgb(bar_color_bottom)
    color_white = to_rgb("#FFFFFF")
    color_top_rgb = to_rgb(bar_color_top)
    cat_plot_cmap = LinearSegmentedColormap.from_list(
        "cat_plot_cmap", [color_bot_rgb, color_white, color_top_rgb], N=200
    )
    ax_hm = fig.add_subplot(gs[2:, :])
    sns.heatmap(data, cmap=cat_plot_cmap, cbar=False, vmin=0, vmax=10, ax=ax_hm)
    # Keep every fifth y tick, rounded to the nearest ten.
    ax_hm.set_yticks(np.round(ax_hm.get_yticks()[0::5], -1))
    ax_hm.set_yticklabels(ax_hm.get_yticks())
    ax_hm.set_xticklabels(
        ax_hm.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize="medium",
    )
    ax_hm.tick_params(length=1, colors="#111111")
    gs.figure.suptitle(
        "Categorical data plot", x=0.5, y=0.91, fontsize=18, color="#111111"
    )

    return gs
176
|
|
|
|
177
|
|
|
|
178
|
|
|
# Correlation Matrix |
179
|
|
|
def corr_mat(
    data: pd.DataFrame,
    split: Optional[str] = None,  # Optional[Literal['pos', 'neg', 'high', 'low']]
    threshold: float = 0,
    target: Optional[Union[pd.DataFrame, pd.Series, np.ndarray, str]] = None,
    method: str = "pearson",  # Literal['pearson', 'spearman', 'kendall']
    colored: bool = True,
) -> Union[pd.DataFrame, Any]:
    """Returns a color-encoded correlation matrix.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label the
        plots
    split : Optional[str], optional
        Type of split to be performed, by default None
        {None, "pos", "neg", "high", "low"}
    threshold : float, optional
        Correlation threshold, validated to lie between -1 and 1, by default 0.
        It is passed on, together with split, to the _corr_selector helper.
    target : Optional[Union[pd.Series, np.ndarray, list, str]], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None.
        A string is treated as the name of a column in data, which is removed
        from the features. A list, Series or array is converted to a Series and
        correlated with every column of data. Any other value results in the
        full correlation matrix.
    method : str, optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
        * pearson: measures linear relationships and requires normally
          distributed and homoscedastic data.
        * spearman: ranked/ordinal correlation, measures monotonic relationships.
        * kendall: ranked/ordinal correlation, measures monotonic relationships.
          Computationally more expensive but more robust in smaller datasets
          than "spearman"
    colored : bool, optional
        If True the negative values in the correlation matrix are colored in
        red, by default True

    Returns
    -------
    Union[pd.DataFrame, pd.Styler]
        If colored = True - corr: Pandas Styler object
        If colored = False - corr: Pandas DataFrame
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(colored, "colored")

    def color_negative_red(val):
        # An empty string leaves the cell unstyled; NaN compares False and
        # therefore stays unstyled as well.
        return "color: #FF3344" if val < 0 else ""

    data = pd.DataFrame(data)

    if isinstance(target, (str, list, pd.Series, np.ndarray)):
        if isinstance(target, str):
            # Target is a column of data: split it off from the features.
            target_data = data[target]
            data = data.drop(target, axis=1)
        else:
            # External target: the result column is labelled with the Series
            # name (None for plain lists and arrays).
            target_data = pd.Series(target)
            target = target_data.name

        corr = pd.DataFrame(data.corrwith(target_data, method=method))
        corr = corr.sort_values(corr.columns[0], ascending=False)
        corr.columns = [target]

    else:
        corr = data.corr(method=method)

    corr = _corr_selector(corr, split=split, threshold=threshold)

    if colored:
        styler = corr.style
        # Styler.applymap was renamed to Styler.map in pandas 2.1; prefer the
        # new name and fall back to the old one on earlier versions.
        style_cells = getattr(styler, "map", None) or styler.applymap
        return style_cells(color_negative_red).format("{:.2f}", na_rep="-")
    return corr
256
|
|
|
|
257
|
|
|
|
258
|
|
|
# Correlation matrix / heatmap |
259
|
|
|
def corr_plot(
    data: pd.DataFrame,
    split: Optional[str] = None,
    threshold: float = 0,
    target: Optional[Union[pd.Series, str]] = None,
    method: str = "pearson",
    cmap: str = "BrBG",
    figsize: Tuple = (12, 10),
    annot: bool = True,
    dev: bool = False,
    **kwargs,
):
    """Two-dimensional visualization of the correlation between feature-columns
    excluding NA values.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label the
        plots
    split : Optional[str], optional
        Type of split to be performed {None, "pos", "neg", "high", "low"}, by
        default None
            * None: visualize all correlations between the feature-columns
            * pos: visualize all positive correlations between the
              feature-columns above the threshold
            * neg: visualize all negative correlations between the
              feature-columns below the threshold
            * high: visualize all correlations between the feature-columns for
              which abs(corr) > threshold is True
            * low: visualize all correlations between the feature-columns for
              which abs(corr) < threshold is True

    threshold : float, optional
        Value between 0 and 1 to set the correlation threshold, by default 0
        unless split = "high" or split = "low", in which case default is 0.3
    target : Optional[Union[pd.Series, str]], optional
        Specify target for correlation. E.g. label column to generate only the
        correlations between each feature and the label, by default None
    method : str, optional
        method: {"pearson", "spearman", "kendall"}, by default "pearson"
            * pearson: measures linear relationships and requires normally
              distributed and homoscedastic data.
            * spearman: ranked/ordinal correlation, measures monotonic
              relationships.
            * kendall: ranked/ordinal correlation, measures monotonic
              relationships. Computationally more expensive but more robust in
              smaller datasets than "spearman".

    cmap : str, optional
        The mapping from data values to color space, matplotlib colormap name or
        object, or list of colors, by default "BrBG"
    figsize : Tuple, optional
        Use to control the figure size, by default (12, 10)
    annot : bool, optional
        Use to show or hide annotations, by default True
    dev : bool, optional
        Display figure settings in the plot by setting dev = True. If False, the
        settings are not displayed, by default False

    Keyword Arguments : optional
        Additional elements to control the visualization of the plot, e.g.:

            * mask: bool, default True
                If set to False the entire correlation matrix, including the
                upper triangle is shown. Set dev = False in this case to avoid
                overlap.
            * vmax: float, default is calculated from the given correlation
              coefficients.
                Value between -1 or vmin <= vmax <= 1, limits the range of the
                cbar.
            * vmin: float, default is calculated from the given correlation
              coefficients.
                Value between -1 <= vmin <= 1 or vmax, limits the range of the
                cbar.
            * linewidths: float, default 0.5
                Controls the line-width inbetween the squares.
            * annot_kws: dict, default {"size" : 10}
                Controls the font size of the annotations. Only available when
                annot = True.
            * cbar_kws: dict, default {"shrink": .95, "aspect": 30}
                Controls the size of the colorbar.
            * Many more kwargs are available, i.e. "alpha" to control blending,
              or options to adjust labels, ticks ...

        Kwargs can be supplied through a dictionary of key-value pairs (see
        above).

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object with the plot for further tweaking.
    """

    # Validate Inputs
    _validate_input_range(threshold, "threshold", -1, 1)
    _validate_input_bool(annot, "annot")
    _validate_input_bool(dev, "dev")

    data = pd.DataFrame(data)

    corr = corr_mat(
        data,
        split=split,
        threshold=threshold,
        target=target,
        method=method,
        colored=False,
    )

    # The builtin bool is used as dtype: the np.bool alias was removed from
    # NumPy in version 1.24.
    mask = np.zeros_like(corr, dtype=bool)

    if target is None:
        # Full matrix: hide the diagonal and the redundant upper triangle.
        mask = np.triu(np.ones_like(corr, dtype=bool))

    # Default colorbar limits are derived from the visible coefficients only.
    vmax = np.round(np.nanmax(corr.where(~mask)) - 0.05, 2)
    vmin = np.round(np.nanmin(corr.where(~mask)) + 0.05, 2)

    fig, ax = plt.subplots(figsize=figsize)

    # Specify kwargs for the heatmap; entries supplied by the caller override
    # the defaults listed first.
    kwargs = {
        "mask": mask,
        "cmap": cmap,
        "annot": annot,
        "vmax": vmax,
        "vmin": vmin,
        "linewidths": 0.5,
        "annot_kws": {"size": 10},
        "cbar_kws": {"shrink": 0.95, "aspect": 30},
        "ax": ax,
        **kwargs,
    }

    # Draw heatmap with mask and default settings
    sns.heatmap(corr, center=0, fmt=".2f", **kwargs)

    ax.set_title(f"Feature-correlation ({method})", fontdict={"fontsize": 18})

    # Settings
    if dev:
        # The backslash continuations keep the leading whitespace of each line
        # inside the string, i.e. it is part of the rendered text.
        fig.suptitle(
            f"\
            Settings (dev-mode): \n\
            - split-mode: {split} \n\
            - threshold: {threshold} \n\
            - method: {method} \n\
            - annotations: {annot} \n\
            - cbar: \n\
                - vmax: {vmax} \n\
                - vmin: {vmin} \n\
            - linewidths: {kwargs['linewidths']} \n\
            - annot_kws: {kwargs['annot_kws']} \n\
            - cbar_kws: {kwargs['cbar_kws']}",
            fontsize=12,
            color="gray",
            x=0.35,
            y=0.85,
            ha="left",
        )

    return ax
415
|
|
|
|
416
|
|
|
|
417
|
|
|
# Distribution plot |
418
|
|
|
def dist_plot(
    data: pd.DataFrame,
    mean_color: str = "orange",
    size: float = 2.5,
    fill_range: Tuple = (0.025, 0.975),
    showall: bool = False,
    kde_kws: Dict[str, Any] = None,
    rug_kws: Dict[str, Any] = None,
    fill_kws: Dict[str, Any] = None,
    font_kws: Dict[str, Any] = None,
):
    """Two-dimensional visualization of the distribution of non binary numerical
    features.

    One figure is created per numerical column with more than two distinct
    values. Each shows a kernel density estimate with rug plot, a shaded
    quantile range, markers for mean, median and mean +/- one standard
    deviation, and a block of summary statistics.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label the
        plots
    mean_color : str, optional
        Color of the vertical line indicating the mean of the data, by default
        "orange"
    size : float, optional
        Controls the plot size, by default 2.5
    fill_range : Tuple, optional
        Set the quantiles for shading. Default spans 95% of the data, which is
        about two std. deviations above and below the mean, by default
        (0.025, 0.975)
    showall : bool, optional
        Set to True to remove the output limit of 20 plots, by default False
    kde_kws : Dict[str, Any], optional
        Keyword arguments for kdeplot(), by default {"alpha": 0.75,
        "linewidth": 1.5, "bw_adjust": 0.8}
    rug_kws : Dict[str, Any], optional
        Keyword arguments for rugplot(), by default {"color": "#ff3333",
        "alpha": 0.15, "lw": 3, "height": 0.075}
    fill_kws : Dict[str, Any], optional
        Keyword arguments to control the fill, by default {"color": "#80d4ff",
        "alpha": 0.2}
    font_kws : Dict[str, Any], optional
        Keyword arguments to control the font, by default {"color": "#111111",
        "weight": "normal", "size": 11}

    Returns
    -------
    ax: matplotlib Axes
        Returns the Axes object of the last plot for further tweaking, or None
        if no suitable numerical column is found.
    """

    # Validate Inputs
    _validate_input_range(fill_range[0], "fill_range_lower", 0, 1)
    _validate_input_range(fill_range[1], "fill_range_upper", 0, 1)
    _validate_input_smaller(fill_range[0], fill_range[1], "fill_range")
    _validate_input_bool(showall, "showall")

    # Handle dictionary defaults (caller supplied dicts are copied, not shared)
    kde_kws = (
        {"alpha": 0.75, "linewidth": 1.5, "bw_adjust": 0.8}
        if kde_kws is None
        else kde_kws.copy()
    )
    rug_kws = (
        {"color": "#ff3333", "alpha": 0.15, "lw": 3, "height": 0.075}
        if rug_kws is None
        else rug_kws.copy()
    )
    fill_kws = (
        {"color": "#80d4ff", "alpha": 0.2} if fill_kws is None else fill_kws.copy()
    )
    font_kws = (
        {"color": "#111111", "weight": "normal", "size": 11}
        if font_kws is None
        else font_kws.copy()
    )

    # Coerce before copying so that any input accepted by pd.DataFrame works.
    data = pd.DataFrame(data).copy().dropna(axis=1, how="all")
    # df keeps every row for the summary statistics; data may be subsampled
    # below and is only used for drawing.
    df = data.copy()
    data = data.loc[:, data.nunique() > 2]
    if data.shape[0] > 10000:
        data = data.sample(n=10000, random_state=408)
        print(
            "Large dataset detected, using 10000 random samples for the plots. Summary"
            " statistics are still based on the entire dataset."
        )
    cols = list(data.select_dtypes(include=["number"]).columns)
    data = data[cols]

    if not cols:
        print("No columns with numeric data were detected.")
        return None

    # Only announce and apply the limit when columns are actually dropped.
    if len(cols) > 20 and showall is False:
        print(
            "Note: The number of non binary numerical features is very large "
            f"({len(cols)}), please consider splitting the data. Showing plots for "
            "the first 20 numerical features. Override this by setting showall=True."
        )
        cols = cols[:20]

    g = None
    for col in cols:
        col_data = data[col].dropna(axis=0)
        col_df = df[col].dropna(axis=0)

        g = sns.displot(
            col_data,
            kind="kde",
            rug=True,
            height=size,
            aspect=5,
            legend=False,
            rug_kws=rug_kws,
            **kde_kws,
        )
        ax = g.axes[0, 0]

        # Vertical lines and fill
        # x / y are the points of the first drawn line (the density curve).
        x, y = ax.lines[0].get_xydata().T
        q_low = np.quantile(col_df, fill_range[0])
        q_high = np.quantile(col_df, fill_range[1])
        ax.fill_between(
            x,
            y,
            where=((x >= q_low) & (x <= q_high)),
            label=f"{fill_range[0]*100:.1f}% - {fill_range[1]*100:.1f}%",
            **fill_kws,
        )

        mean = np.mean(col_df)
        median = np.median(col_df)
        std = scipy.stats.tstd(col_df)
        # Each marker ends where it meets the (interpolated) density curve.
        ax.vlines(
            x=mean,
            ymin=0,
            ymax=np.interp(mean, x, y),
            ls="dotted",
            color=mean_color,
            lw=2,
            label="mean",
        )
        ax.vlines(
            x=median,
            ymin=0,
            ymax=np.interp(median, x, y),
            ls=":",
            color=".3",
            label="median",
        )
        ax.vlines(
            x=[mean - std, mean + std],
            ymin=0,
            ymax=[np.interp(mean - std, x, y), np.interp(mean + std, x, y)],
            ls=":",
            color=".5",
            label="\u03BC \u00B1 \u03C3",
        )

        ax.set_ylim(0)
        # Widen the x-range, more so on the left where the statistics are put.
        x_min, x_max = ax.get_xlim()
        ax.set_xlim(x_min - x_max * 0.05, x_max * 1.03)

        # Annotations and legend
        annotations = (
            (0.9, f"Mean: {mean:.2f}"),
            (0.7, f"Std. dev: {std:.2f}"),
            (0.5, f"Skew: {scipy.stats.skew(col_df):.2f}"),
            (0.3, f"Kurtosis: {scipy.stats.kurtosis(col_df):.2f}"),  # Excess Kurtosis
            (0.1, f"Count: {len(col_df)}"),
        )
        for y_pos, text in annotations:
            ax.text(0.005, y_pos, text, fontdict=font_kws, transform=ax.transAxes)
        ax.legend(loc="upper right")

    return g.axes[0, 0]
616
|
|
|
|
617
|
|
|
|
618
|
|
|
# Missing value plot |
619
|
|
|
def missingval_plot(
    data: pd.DataFrame,
    cmap: str = "PuBuGn",
    figsize: Tuple = (20, 20),
    sort: bool = False,
    spine_color: str = "#EEEEEE",
):
    """Two-dimensional visualization of the missing values in a dataset.

    The figure combines four panels: a bar chart with the share of missing
    values per column, a heatmap marking every missing cell, a text summary and
    a scatter plot of the number of missing values per row.

    Parameters
    ----------
    data : pd.DataFrame
        2D dataset that can be coerced into Pandas DataFrame. If a Pandas
        DataFrame is provided, the index/column information is used to label the
        plots
    cmap : str, optional
        Any valid colormap can be used. E.g. "Greys", "RdPu". More information
        can be found in the matplotlib documentation, by default "PuBuGn"
    figsize : Tuple, optional
        Use to control the figure size, by default (20, 20)
    sort : bool, optional
        Sort columns based on missing values in descending order and drop
        columns without any missing values, by default False
    spine_color : str, optional
        Set to "None" to hide the spines on all plots or use any valid
        matplotlib color argument, by default "#EEEEEE"

    Returns
    -------
    GridSpec
        gs: Figure with array of Axes objects. None is returned (after printing
        a note) if the dataset has no missing values.
    """

    # Validate Inputs
    _validate_input_bool(sort, "sort")

    data = pd.DataFrame(data)

    if sort:
        # Columns ordered by their number of missing values, complete columns
        # removed.
        mv_cols_sorted = data.isna().sum(axis=0).sort_values(ascending=False)
        final_cols = mv_cols_sorted[mv_cols_sorted > 0].index.tolist()
        data = data[final_cols]
        print("Displaying only columns with missing values.")

    # Identify missing values
    mv_total, mv_rows, mv_cols, _, mv_cols_ratio = _missing_vals(data).values()
    total_datapoints = data.shape[0] * data.shape[1]

    if mv_total == 0:
        print("No missing values found in the dataset.")
        return None

    # Create figure and axes
    fig = plt.figure(figsize=figsize)
    gs = fig.add_gridspec(nrows=6, ncols=6, left=0.1, wspace=0.05)
    ax1 = fig.add_subplot(gs[:1, :5])
    ax2 = fig.add_subplot(gs[1:, :5])
    ax3 = fig.add_subplot(gs[:1, 5:])
    ax4 = fig.add_subplot(gs[1:, 5:])

    # ax1 - Barplot
    colors = plt.get_cmap(cmap)(mv_cols / np.max(mv_cols))  # color bars by height
    ax1.bar(range(len(mv_cols)), np.round((mv_cols_ratio) * 100, 2), color=colors)
    ax1.get_xaxis().set_visible(False)
    ax1.set(frame_on=False, xlim=(-0.5, len(mv_cols) - 0.5))
    ax1.set_ylim(0, np.max(mv_cols_ratio) * 100)
    ax1.grid(linestyle=":", linewidth=1)
    ax1.yaxis.set_major_formatter(ticker.PercentFormatter(decimals=0))
    ax1.tick_params(axis="y", colors="#111111", length=1)

    # annotate values on top of the bars
    for rect, label in zip(ax1.patches, mv_cols):
        height = rect.get_height()
        ax1.text(
            0.1 + rect.get_x() + rect.get_width() / 2,
            height + 0.5,
            label,
            ha="center",
            va="bottom",
            # numeric value: Text.set_rotation rejects strings such as "90"
            rotation=90,
            alpha=0.5,
            fontsize=11,
        )

    ax1.set_frame_on(True)
    for spine in ax1.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)
    ax1.spines["top"].set_color(None)

    # ax2 - Heatmap
    sns.heatmap(data.isna(), cbar=False, cmap="binary", ax=ax2)
    # Keep every fifth y tick, rounded to the nearest ten.
    ax2.set_yticks(np.round(ax2.get_yticks()[0::5], -1))
    ax2.set_yticklabels(ax2.get_yticks())
    ax2.set_xticklabels(
        ax2.get_xticklabels(),
        horizontalalignment="center",
        fontweight="light",
        fontsize=12,
    )
    ax2.tick_params(length=1, colors="#111111")
    for spine in ax2.spines.values():
        spine.set_visible(True)
        spine.set_color(spine_color)

    # ax3 - Summary
    fontax3 = {"color": "#111111", "weight": "normal", "size": 14}
    ax3.get_xaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    ax3.set(frame_on=False)

    summary_lines = (
        (0.875, f"Total: {np.round(total_datapoints/1000,1)}K"),
        (0.675, f"Missing: {np.round(mv_total/1000,1)}K"),
        (0.475, f"Relative: {np.round(mv_total/total_datapoints*100,1)}%"),
        (0.275, f"Max-col: {np.round(mv_cols.max()/data.shape[0]*100)}%"),
        (0.075, f"Max-row: {np.round(mv_rows.max()/data.shape[1]*100)}%"),
    )
    for y_pos, text in summary_lines:
        ax3.text(0.025, y_pos, text, transform=ax3.transAxes, fontdict=fontax3)

    # ax4 - Scatter plot
    ax4.get_yaxis().set_visible(False)
    for spine in ax4.spines.values():
        spine.set_color(spine_color)
    ax4.tick_params(axis="x", colors="#111111", length=1)

    # One point per row; size and color both encode the number of missing
    # values in that row.
    ax4.scatter(
        mv_rows,
        range(len(mv_rows)),
        s=mv_rows,
        c=mv_rows,
        cmap=cmap,
        marker=".",
        vmin=1,
    )
    ax4.set_ylim((0, len(mv_rows))[::-1])  # limit and invert y-axis
    ax4.set_xlim(0, max(mv_rows) + 0.5)
    ax4.grid(linestyle=":", linewidth=1)

    gs.figure.suptitle(
        "Missing value plot", x=0.45, y=0.94, fontsize=18, color="#111111"
    )

    return gs
795
|
|
|
|