1
|
|
|
import abc |
2
|
|
|
|
3
|
|
|
|
4
|
|
|
class Dataset(abc.ABC):
    """Base class for datasets.

    Every method decorated with :func:`abc.abstractmethod` (including
    ``__init__`` and ``_validate``, which carry default implementations)
    must be overridden by a concrete subclass; the implementations here
    can be reused through ``super()``.

    Attributes
        - `df` - :class:`pandas.DataFrame` that holds the actual data.

        - `target` - Column name of the variable to predict
                     (ground truth)

        - `sensitive_attributes` - Column names of the
                                   sensitive attributes

        - `prediction` - Column name of the
                         prediction (optional)
    """

    @abc.abstractmethod
    def __init__(self, target, sensitive_attributes, prediction=None):
        """Load, preprocess and validate the dataset.

        The steps run in a fixed order: ``_load_data`` fills ``self.df``,
        ``_preprocess`` runs next, the column names are stored on the
        instance, and finally ``_validate`` is called.

        :param target: Column name of the variable
                       to predict (ground truth)
        :param sensitive_attributes: Column names of the
                                     sensitive attributes
        :param prediction: Column name of the
                           prediction (optional)
        :type target: str
        :type sensitive_attributes: list
        :type prediction: str
        """

        self.df = self._load_data()

        self._preprocess()

        # The display name used by ``__str__`` is the first line of the
        # concrete class's docstring.  Class docstrings are not inherited,
        # so a subclass without one has ``__doc__ is None``; fall back to
        # the class name instead of crashing during construction.
        doc_lines = (self.__doc__ or '').splitlines()
        self._name = doc_lines[0] if doc_lines else type(self).__name__

        self.target = target
        self.sensitive_attributes = sensitive_attributes
        self.prediction = prediction

        self._validate()

    def __str__(self):
        """Return a one-line summary: name, row/column counts and the
        sensitive attributes."""
        return ('<{} {} rows, {} columns'
                ' in which {{{}}} are sensitive attributes>'
                .format(self._name,
                        len(self.df),
                        len(self.df.columns),
                        ', '.join(self.sensitive_attributes)))

    @abc.abstractmethod
    def _load_data(self):
        """Return the data; the result is stored as ``self.df``."""
        pass

    @abc.abstractmethod
    def _preprocess(self):
        """Hook called right after ``self.df`` has been loaded."""
        pass

    @abc.abstractmethod
    def _validate(self):
        """Check that the configured columns exist in ``self.df``.

        :raises AssertionError: if ``target`` or any of the
                                ``sensitive_attributes`` is not a column
                                of ``self.df``.
        """
        # pylint: disable=line-too-long

        assert self.target in self.df.columns,\
            ('the target label \'{}\' should be in the columns'
             .format(self.target))

        # Collect the offending names once so that the check and the
        # error message cannot disagree.
        missing = [attr for attr in self.sensitive_attributes
                   if attr not in self.df.columns]
        assert not missing,\
            ('the sensitive attributes {{{}}} should be in the columns'
             .format(','.join(missing)))

        # NOTE: disabled check kept for reference; SENSITIVE_ATTRIBUTES is
        # not defined in the visible part of this module.
        # assert all(attr in SENSITIVE_ATTRIBUTES
        #            for attr in self.sensitive_attributes),\
        #     ('the sensitive attributes {} can be only from {}.'  # noqa
        #      .format(self.sensitive_attributes, SENSITIVE_ATTRIBUTES))
82
|
|
|
|