Passed
Push — master ( 170db5...8af2aa )
by Shlomi
02:43 queued 58s
created

ethically.dataset.core.Dataset._preprocess()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import abc
2
3
4
class Dataset(abc.ABC):
    """Base class for datasets.

    Concrete datasets implement :meth:`_load_data`, :meth:`_preprocess`
    and :meth:`_validate`; the constructor runs them in that order.

    Attributes
        - `df` - :class:`pandas.DataFrame` that holds the actual data.

        - `target` - Column name of the variable to predict
                    (ground truth)

        - `sensitive_attributes` - Column names of the
                                sensitive attributes

        - `prediction` - Column name of the
                        prediction (optional)

    """

    @abc.abstractmethod
    def __init__(self, target, sensitive_attributes, prediction=None):
        """Load, preprocess and validate the dataset.

        :param target: Column name of the variable
                    to predict (ground truth)
        :param sensitive_attributes: Column names of the
                                    sensitive attributes
        :param prediction: Column name of the
                           prediction (optional)
        :type target: str
        :type sensitive_attributes: list
        :type prediction: str
        :raises AssertionError: if :meth:`_validate` rejects the
                                loaded data.
        """

        # Loading and preprocessing run before the column names are
        # stored, so `_preprocess` cannot rely on `self.target` etc.
        self.df = self._load_data()

        self._preprocess()

        # The display name is the first non-blank line of the concrete
        # class's docstring.  A missing or blank docstring used to crash
        # here (`None.splitlines()` / `[][0]`), so fall back to the
        # class name instead.
        doc_lines = (self.__doc__ or '').strip().splitlines()
        if doc_lines:
            self._name = doc_lines[0].strip()
        else:
            self._name = type(self).__name__

        self.target = target
        self.sensitive_attributes = sensitive_attributes
        self.prediction = prediction

        self._validate()

    def __str__(self):
        """Return a one-line summary: name, shape and sensitive attributes."""
        return ('<{} {} rows, {} columns'
                ' in which {{{}}} are sensitive attributes>'
                .format(self._name,
                        len(self.df),
                        len(self.df.columns),
                        ', '.join(self.sensitive_attributes)))

    @abc.abstractmethod
    def _load_data(self):
        """Return the raw data; the result is stored as `self.df`."""

    @abc.abstractmethod
    def _preprocess(self):
        """Clean up `self.df`; called right after the data is loaded."""

    @abc.abstractmethod
    def _validate(self):
        """Check that the configured columns exist in `self.df`.

        Subclasses must override this method and may call
        ``super()._validate()`` to reuse these checks.

        :raises AssertionError: if `target` or any of the
                                `sensitive_attributes` is not a column
                                of `self.df`.
        """

        columns = self.df.columns

        # Raised explicitly rather than with `assert` statements so the
        # checks still run under `python -O`; the exception type and the
        # messages are the same as before.
        if self.target not in columns:
            raise AssertionError(
                'the target label \'{}\' should be in the columns'
                .format(self.target))

        missing = [attr for attr in self.sensitive_attributes
                   if attr not in columns]
        if missing:
            raise AssertionError(
                'the sensitive attributes {{{}}} should be in the columns'
                .format(','.join(missing)))

        # Disabled check, kept for reference (SENSITIVE_ATTRIBUTES is
        # not defined in this module):
        # assert all(attr in SENSITIVE_ATTRIBUTES
        #           for attr in self.sensitive_attributes),\
        # ('the sensitive attributes {} can be only from {}.'  # noqa
        #  .format(self.sensitive_attributes, SENSITIVE_ATTRIBUTES))
82