Passed
Push — master ( 170db5...8af2aa )
by Shlomi
02:43 queued 58s
created

ethically.dataset.compas.COMPASDataset.__init__()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
# Public names exported by ``from ethically.dataset.compas import *``.
__all__ = [
    'COMPASDataset',
]
2
3
import numpy as np
4
import pandas as pd
5
from pkg_resources import resource_filename
6
7
from ethically.dataset.core import Dataset
8
9
10
# Filesystem path of the bundled ProPublica CSV, looked up as a package
# resource of this module via pkg_resources.
COMPAS_PATH = resource_filename(
    __name__,
    'compas-scores-two-years.csv',
)
12
13
14
class COMPASDataset(Dataset):
    """ProPublica Recidivism/COMPAS Dataset.

    See :class:`~ethically.dataset.Dataset` for a description of
    the arguments and attributes.

    References:
        https://github.com/propublica/compas-analysis

    """

    def __init__(self):
        """Configure the dataset columns.

        ``is_recid`` is the target, ``race`` and ``sex`` are the sensitive
        attributes, and ``y_pred``, ``score_factor`` and ``score_text`` are
        the prediction columns.
        """
        super().__init__(target='is_recid',
                         sensitive_attributes=['race', 'sex'],
                         prediction=['y_pred', 'score_factor',
                                     'score_text'])

    def _load_data(self):
        """Read the bundled ``compas-scores-two-years.csv`` file.

        Returns:
            pandas.DataFrame: the raw CSV contents.
        """
        return pd.read_csv(COMPAS_PATH)

    def _preprocess(self):
        """Perform the same preprocessing as the original analysis.

        https://github.com/propublica/compas-analysis/blob/master/Compas%20Analysis.ipynb

        Keeps only rows where ``days_b_screening_arrest`` is within
        [-30, 30], ``is_recid`` is not -1, ``c_charge_degree`` is not 'O',
        and ``score_text`` is present. Then it parses the jail in/out
        timestamps and adds the ``length_of_stay``, ``score_factor`` and
        ``y_pred`` columns. Replaces ``self.df`` with the result.
        """
        df = self.df
        screening_gap = df['days_b_screening_arrest']
        score_text = df['score_text']

        keep = ((screening_gap <= 30)
                & (screening_gap >= -30)
                & (df['is_recid'] != -1)
                & (df['c_charge_degree'] != 'O')
                # pandas.read_csv parses the literal 'N/A' as NaN by default,
                # and NaN != 'N/A' is True, so the string comparison alone
                # would never drop a missing score; test for NaN explicitly.
                & score_text.notna()
                & (score_text != 'N/A'))

        # .copy() detaches the filtered rows from the unfiltered frame so the
        # column assignments below write to an independent DataFrame instead
        # of triggering pandas' SettingWithCopyWarning.
        df = df[keep].copy()

        df['c_jail_out'] = pd.to_datetime(df['c_jail_out'])
        df['c_jail_in'] = pd.to_datetime(df['c_jail_in'])
        df['length_of_stay'] = df['c_jail_out'] - df['c_jail_in']

        # Any score text other than 'Low' is treated as a high score.
        df['score_factor'] = np.where(df['score_text'] != 'Low',
                                      'HighScore', 'LowScore')
        df['y_pred'] = (df['score_factor'] == 'HighScore')

        self.df = df

    def _validate(self):
        """Check the preprocessed frame has the expected dimensions.

        Runs the base-class validation first.

        Raises:
            AssertionError: if there are not exactly 6172 rows or not
                exactly 56 columns. Note that these checks are ``assert``
                statements and are skipped when Python runs with ``-O``.
        """
        super()._validate()

        n_rows, n_columns = self.df.shape
        assert n_rows == 6172, ('the number of rows should be 6172,'
                                ' but it is {}.'.format(n_rows))
        assert n_columns == 56, ('the number of columns should be 56,'
                                 ' but it is {}.'.format(n_columns))
64