Passed
Push — mpeta ( 62640f...eed483 )
by Konstantinos
01:41
created

OneHotStringEncoder._encode_none()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 3
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
from abc import ABC, abstractmethod
2
import attr
3
from so_magic.utils import SubclassRegistry
4
5
6
class EncoderInterface(ABC):
    """Abstract interface shared by all encoders.

    Concrete encoders must override ``encode``; the base implementation only
    raises ``NotImplementedError``.
    """

    @abstractmethod
    def encode(self, *args, **kwargs):
        """Encode the given input. Must be implemented by subclasses."""
        raise NotImplementedError
10
11
12
class EncoderFactoryType(type):
    """Metaclass exposing a ``create`` factory hook for encoders.

    At this level ``create`` is only a placeholder: it always raises
    ``NotImplementedError``.
    """

    @classmethod
    def create(mcs, *args, **kwargs) -> EncoderInterface:
        """Build an encoder from the given arguments (not implemented here)."""
        raise NotImplementedError
17
18
19
# class NominalVariableEncoderFactory:
20
#     @classmethod
21
#     def create(cls, *args, **kwargs) -> EncoderInterface:
22
23
24
25
@attr.s(slots=True)
class NominalAttributeEncoder(EncoderInterface, ABC):
    """Encode the observations of a categorical nominal variable.

    Client code may supply the possible values of the nominal variable when they
    are known a priori; they are stored in the ``values_set`` attribute. When they
    are not supplied, they should be computed at runtime (when running ``encode``).

    The string identifier of every produced column is stored in the ``columns``
    attribute.

    Args:
        values_set (list): the possible values of the nominal variable
            observations, if known a priori
    """
    # Possible values of the variable; each instance gets its own fresh empty list by default.
    values_set: list = attr.ib(factory=list)
    # Identifiers of the produced columns; not an __init__ argument, starts as an empty list.
    columns: list = attr.ib(factory=list, init=False)
40
41
42
class EncoderFactoryClassRegistry(metaclass=SubclassRegistry):
    """Registry that encoder classes join via ``register_as_subclass(<key>)``.

    Registration and the ``create(key, ...)`` lookup used by ``EncoderFactory``
    are provided by the ``SubclassRegistry`` metaclass.
    """
43
44
from functools import reduce
45
import pandas as pd
46
47
48
@EncoderFactoryClassRegistry.register_as_subclass('nominal_list')
class OneHotListEncoder(EncoderInterface):
    """One-hot encode an attribute whose observations are lists of values.

    Every distinct value found across all list-valued observations becomes one
    column, named ``<attribute><column_name_joiner><value>``. A row gets 1.0 in each
    column whose value appears in its list and 0.0 elsewhere; observations that are
    not lists encode to all zeros.
    """
    binary_transformer = {True: 1.0, False: 0.0}
    column_name_joiner = '_'

    def __init__(self, *args, **kwargs) -> None:
        # Constructor arguments are accepted and ignored (presumably the registry's
        # create() forwards datapoints/variable/scheme here -- verify).
        # Initialise the state so get_feature_names() works before encode().
        self.values_set = set()
        self.columns = []

    def encode(self, *args, **kwargs):
        """Build the one-hot DataFrame for an attribute.

        Args:
            args[0]: datapoints; must expose ``observations[<name>]`` and ``iterrows()``.
            args[1]: attribute; its ``str()`` is used as the column name.

        Returns:
            pandas.DataFrame with one row per datapoint and one column per distinct value.
        """
        datapoints, attribute = args[0], args[1]
        attribute_name = str(attribute)
        list_observations = [obs for obs in datapoints.observations[attribute_name]
                             if isinstance(obs, list)]
        # set().union(*...) always yields a set: it copes with zero lists (reduce
        # without an initializer raised TypeError) and de-duplicates a single list.
        self.values_set = set().union(*list_observations)
        # Build column names from the same ordering _encode uses for the vector
        # entries, so headers line up with data even for non-string values.
        self.columns = [f'{attribute_name}{self.column_name_joiner}{value}'
                        for value in sorted(self.values_set)]
        return pd.DataFrame([self._yield_vector(datarow, attribute_name) for _, datarow in datapoints.iterrows()],
                            columns=self.columns)

    def _yield_vector(self, datarow, attribute):
        """Return the encoded vector of one row; non-list cells encode to all zeros."""
        encoder = self._encode if isinstance(datarow[str(attribute)], list) else self._encode_none
        return encoder(datarow, str(attribute))

    def _encode(self, datarow, attribute):
        """Return 1.0/0.0 per known value (sorted order) by membership in the row's list."""
        row_values = datarow[str(attribute)]
        return [self.binary_transformer[value in row_values] for value in sorted(self.values_set)]

    def _encode_none(self, _datarow, _attribute):
        """Return an all-zeros vector with one entry per known value."""
        return [0.0] * len(self.values_set)

    def get_feature_names(self):
        """Return the column names produced by the last ``encode`` call."""
        return self.columns
81
82
83
84
@EncoderFactoryClassRegistry.register_as_subclass('nominal_str')
class OneHotStringEncoder(EncoderInterface):
    """One-hot encode an attribute whose observations are single string values.

    Every distinct string observation becomes one column, named
    ``<attribute><column_name_joiner><value>``. A row gets 1.0 in the column equal to
    its value and 0.0 elsewhere; non-string observations (presumably missing
    values) encode to all zeros.
    """
    binary_transformer = {True: 1.0, False: 0.0}
    column_name_joiner = '_'

    def __init__(self, *args, **kwargs) -> None:
        # Constructor arguments are accepted and ignored (presumably the registry's
        # create() forwards datapoints/variable/scheme here -- verify).
        # Initialise the state so get_feature_names() works before encode().
        self.values_set = set()
        self.columns = []

    def encode(self, *args, **kwargs):
        """Build the one-hot DataFrame for an attribute.

        Args:
            args[0]: datapoints; must expose ``observations[<name>]`` and ``iterrows()``.
            args[1]: attribute; its ``str()`` is used as the column name.

        Returns:
            pandas.DataFrame with one row per datapoint and one column per distinct value.
        """
        datapoints, attribute = args[0], args[1]
        attribute_name = str(attribute)
        self.values_set = {value for value in datapoints.observations[attribute_name]
                           if isinstance(value, str)}
        # Column names follow the same sorted(values_set) order that _encode uses.
        self.columns = [f'{attribute_name}{self.column_name_joiner}{value}'
                        for value in sorted(self.values_set)]
        return pd.DataFrame([self._yield_vector(datarow, attribute_name) for _, datarow in datapoints.iterrows()],
                            columns=self.columns)

    def _yield_vector(self, datarow, attribute):
        """Return the encoded vector of one row; non-string cells encode to all zeros."""
        encoder = self._encode if isinstance(datarow[str(attribute)], str) else self._encode_none
        return encoder(datarow, str(attribute))

    def _encode(self, datarow, attribute):
        """Return 1.0 for the known value equal to the row's value, 0.0 for the rest (sorted order)."""
        cell = datarow[str(attribute)]
        return [self.binary_transformer[variable_value == cell] for variable_value in sorted(self.values_set)]

    def _encode_none(self, _datarow, _attribute):
        """Return an all-zeros vector with one entry per known value."""
        return [0.0] * len(self.values_set)

    def get_feature_names(self):
        """Return the column names produced by the last ``encode`` call."""
        return self.columns
113
114
115
@attr.s
class EncoderFactory:
    """Create the encoder registered for a variable's (type, data type) pair.

    The registry defaults to ``EncoderFactoryClassRegistry``; lookup keys have the
    form ``'<str(variable.type) lowercased>_<variable.data_type.__name__>'``.
    """
    encoder_factory_classes_registry = attr.ib(default=attr.Factory(lambda: EncoderFactoryClassRegistry))

    def create(self, datapoints, variable, scheme='auto'):
        """Look up the registry key for ``variable`` and create the matching encoder.

        Args:
            datapoints: the data to encode, forwarded to the registry.
            variable: must expose ``type`` and ``data_type`` (used to build the key).
            scheme (str): encoding scheme, forwarded to the registry (default 'auto').
        """
        key = self.get_key(variable)
        # Forward the caller's scheme; it was previously hard-coded to 'auto',
        # silently ignoring the argument.
        return self.encoder_factory_classes_registry.create(key, datapoints, variable, scheme=scheme)

    def get_key(self, variable):
        """Return the registry key ``'<type lowercased>_<data type name>'`` for ``variable``."""
        return f'{str(variable.type).lower()}_{str(variable.data_type.__name__)}'
124
125
126
@attr.s
class MagicEncoderFactory:
    """Facade that delegates encoder creation to an internally built ``EncoderFactory``."""
    encoder_factory = attr.ib(init=False, default=attr.Factory(lambda: EncoderFactory()))

    def create(self, datapoints, variable, scheme='auto'):
        """Create an encoder for ``variable`` via the wrapped ``EncoderFactory``.

        Args:
            datapoints: the data to encode.
            variable: the variable whose type/data type select the encoder.
            scheme (str): encoding scheme, forwarded unchanged (default 'auto').
        """
        # Forward the caller's scheme instead of always passing 'auto'.
        return self.encoder_factory.create(datapoints, variable, scheme=scheme)
132