RegularizersFactory.__init__()   F
last analyzed

Complexity

Conditions 22

Size

Total Lines 63
Code Lines 58

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 22
eloc 58
nop 2
dl 0
loc 63
rs 0
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

Complexity

Complex classes like patm.modeling.regularization.regularizers_factory.RegularizersFactory.__init__() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from os import path
2
import artm
3
from collections import OrderedDict
4
from configparser import ConfigParser
5
6
from topic_modeling_toolkit.patm.utils import cfg2model_settings
7
from topic_modeling_toolkit.patm.definitions import DEFAULT_CLASS_NAME, IDEOLOGY_CLASS_NAME  # this is the name of the default modality. it is irrelevant to class lebels or document lcassification
8
9
from .regularizers import ArtmRegularizerWrapper
10
11
import logging
12
logger = logging.getLogger(__name__)
13
14
import attr
15
16
# Absolute directory that holds this module (symlinks resolved by realpath).
my_dir = path.split(path.realpath(__file__))[0]
# Path of the 'regularizers.cfg' file that sits next to this module.
REGULARIZERS_CFG = path.join(my_dir, './regularizers.cfg')
18
19
20
# Groups of parameter names shared by several regularizer types.
set1 = ('tau', 'gamma', 'class_ids', 'topic_names')
set2 = ('tau', 'topic_names', 'alpha_iter', 'doc_titles', 'doc_topic_coef')
decorrelation = set1 + ('topic_pairs',)

# Maps each regularizer type to the names of its parameters: the reasonable ones to experiment with by setting
# them to different values and by changing them between training cycles.
#
# Entries that are currently disabled:
#   'decorrelate-phi-def': decorrelation         (default token modality)
#   'decorrelate-phi-class': decorrelation       (doc class modality)
#   'decorrelate-phi-domain': decorrelation      (all modalities)
#   'decorrelate-phi-background': decorrelation  (all modalities)
#   'kl-function-info': ('function_type', 'power_value')
#   'specified-sparse-phi': ('tau', 'gamma', 'topic_names', 'class_id', 'num_max_elements',
#                            'probability_threshold', 'sparse_by_column')
#   'smooth-ptdw': ('tau', 'topic_names', 'alpha_iter')
#   'topic-selection': ('tau', 'topic_names', 'alpha_iter')
regularizer2parameters = {
    # regularizers of the phi matrix
    'smooth-phi': set1,
    'sparse-phi': set1,
    # regularizers of the theta matrix
    'smooth-theta': set2,
    'sparse-theta': set2,
    'decorrelate-phi': decorrelation,
    'label-regularization-phi': set1,
    'improve-coherence': set1,
}

# Parameters that can be dynamically updated during training; ie tau coefficients for sparsing regularizers
# should increase in absolute value (start below zero and decrease).
supported_dynamic_parameters = ('tau', 'gamma')

# For every regularizer type, the subset of its parameters that supports dynamic updates (order preserved).
REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH = {
    reg_type: [name for name in parameter_names if name in supported_dynamic_parameters]
    for reg_type, parameter_names in regularizer2parameters.items()
}
47
48
49
# Maps a regularizer setting name to the callable that encodes its raw (string) value as read from a cfg file.
#
# Parameters that are documented here but have no encoder yet:
#   class_ids (list): list of class_ids or single class_id to regularize; regularizes all classes if empty or
#       None [ONLY for ImproveCoherencePhiRegularizer: dictionary should contain pairwise tokens co-occurrence info]
#   topic_names (list): specific topics to target for applying regularization; targets all if None
#   topic_pairs (dict): key=topic_name, value=dict of pairwise topic decorrelation coefficients; if None all
#       values are equal to 1.0
#   function_type (str): the type of function, 'log' (logarithm) or 'pol' (polynomial)
#   power_value (float): the float power of the polynomial; ignored if 'function_type' = 'log'
#   doc_titles (list of str): titles of the documents to be processed by the regularizer. The default empty value
#       means processing of all documents. The user should guarantee the existence and correctness of document
#       titles in batches (e.g. in src files with data, like WV).
#   doc_topic_coef (list of floats or list of lists of floats): two cases:
#       1) list of floats with length equal to the number of topics: an additional multiplier in the M-step
#          formula besides alpha and tau, unique for each topic but common to all processed documents.
#       2) list of lists of floats, with outer length equal to len(doc_titles) and each inner length equal to the
#          number of topics: case 1 with a unique list of multipliers per document of doc_titles. Other documents
#          are not regularized, according to the description of doc_titles.
#       Note that doc_topic_coef and topic_names are both used.
#   num_max_elements (int): number of elements to save in row/column for the artm.SpecifiedSparsePhiRegularizer
#   probability_threshold (float): if m elements in row/column sum into a value >= probability_threshold, m < n,
#       then only these elements are saved. Value should be in (0, 1), default=None
#   sparse_by_columns (bool): find max elements in column or row
parameter_name2encoder = dict(
    # the coefficient of regularization for this regularizer
    tau=str,
    # the number of iterations to initially keep the regularizer turned off
    start=str,
    # the coefficient of relative regularization for this regularizer
    gamma=float,
    # list of additional coefficients of regularization on each iteration over document. Should have length
    # equal to model.num_document_passes (of an artm.ARTM model object)
    alpha_iter=str,
)
65
66
67
class RegularizersFactory:
    """Supports construction of Smoothing, Sparsing, Label, Decorrelating, Coherence-improving regularizers.

    Currently, it resorts to supplying the internal artm.Dictionary to the regularizers' constructors whenever
    possible. Every supported regularizer definition (eg 'smooth-phi-dom-cls') is mapped to a callable that
    receives the regularizer's settings dict and delegates to ArtmRegularizerWrapper.create.
    """

    def __init__(self, dictionary):
        """
        :param artm.Dictionary dictionary: this object shall be passed as regularizers' constructors argument when possible
        """
        self._dictionary = dictionary
        self._regs_data = None
        self._reg_settings = {}
        self._back_t, self._domain_t = [], []
        self._regularizer_type2constructor = self._build_constructors()
        logger.info("Initialized RegularizersFactory with artm.Dictionary '{}'".format(self._dictionary.name))

    def _create_label_phi(self, settings, topic_names, class_ids):
        """Create a 'label-regularization-phi' wrapper for the given topics and class ids, using the factory's dictionary."""
        return ArtmRegularizerWrapper.create('label-regularization-phi', settings, topic_names,
                                             dictionary=self._dictionary,
                                             class_ids=class_ids)

    def _build_constructors(self):
        """Build the hash mapping each supported regularizer definition to its constructor callable.

        Every callable reads the background/domain topics (and the dictionary) from the factory at call time, so
        topics assigned after initialization (through 'regs_data' or 'create_reg_wrappers') are the ones used.

        :return: regularizer definition (str) -> callable accepting the settings dict
        :rtype: dict
        """
        create = ArtmRegularizerWrapper.create
        return {
            'smooth-phi': lambda x: create('smooth-phi', x, self._back_t, [DEFAULT_CLASS_NAME]),
            'smooth-theta': lambda x: create('smooth-theta', x, self._back_t),
            'sparse-phi': lambda x: create('sparse-phi', x, self._domain_t, [DEFAULT_CLASS_NAME]),
            'sparse-theta': lambda x: create('sparse-theta', x, self._domain_t),
            'smooth-phi-dom-cls': lambda x: create('smooth-phi', x, self._domain_t, [IDEOLOGY_CLASS_NAME]),
            'smooth-phi-bac-cls': lambda x: create('smooth-phi', x, self._back_t, [IDEOLOGY_CLASS_NAME]),
            'smooth-phi-cls': lambda x: create('smooth-phi', x, self._back_t + self._domain_t, [IDEOLOGY_CLASS_NAME]),
            'label-regularization-phi-dom-cls': lambda x: self._create_label_phi(x, self._domain_t, IDEOLOGY_CLASS_NAME),
            'decorrelate-phi-dom-def': lambda x: create('decorrelate-phi', x, self._domain_t,
                                                        class_ids=DEFAULT_CLASS_NAME),
            # class_ids=None targets all classes, since no CLASS_LABELS list is given
            'label-regularization-phi-dom-all': lambda x: self._create_label_phi(x, self._domain_t, None),
            'label-regularization-phi-bac-all': lambda x: self._create_label_phi(x, self._back_t, None),
            'label-regularization-phi-dom-def': lambda x: self._create_label_phi(x, self._domain_t, [DEFAULT_CLASS_NAME]),
            'label-regularization-phi-bac-def': lambda x: self._create_label_phi(x, self._back_t, DEFAULT_CLASS_NAME),
            'label-regularization-phi-bac-cls': lambda x: self._create_label_phi(x, self._back_t, IDEOLOGY_CLASS_NAME),
            'label-regularization-phi-all': lambda x: self._create_label_phi(x, self._domain_t + self._back_t, None),
            'label-regularization-phi-def': lambda x: self._create_label_phi(x, self._domain_t + self._back_t,
                                                                             DEFAULT_CLASS_NAME),
            'label-regularization-phi-cls': lambda x: self._create_label_phi(x, self._domain_t + self._back_t,
                                                                             IDEOLOGY_CLASS_NAME),
            # NOTE(review): the next two entries pass their own long names to create(), whereas every other entry
            # passes a base regularizer type (eg 'decorrelate-phi'). Kept as found -- confirm it is intended.
            'decorrelate-phi-def': lambda x: create('decorrelate-phi-def', x, self._domain_t,
                                                    class_ids=DEFAULT_CLASS_NAME),
            'decorrelate-phi-class': lambda x: create('decorrelate-phi-class', x, self._domain_t,
                                                      class_ids=IDEOLOGY_CLASS_NAME),
            'decorrelate-phi-background': lambda x: create('decorrelate-phi', x, self._back_t,
                                                           class_ids=None),
            'improve-coherence': lambda x: create('improve-coherence', x, self._domain_t, self._dictionary,
                                                  class_ids=DEFAULT_CLASS_NAME),
        }

    @property
    def regs_data(self):
        """The RegularizersData currently in use, or None if none has been set yet."""
        return self._regs_data

    @regs_data.setter
    def regs_data(self, regs_data):
        """
        Also updates the targeted background/domain topics and the regularizers' settings from the given object.

        :param RegularizersData regs_data:
        """
        self._regs_data = regs_data
        self._back_t, self._domain_t = regs_data.background_topics, regs_data.domain_topics
        self._reg_settings = regs_data.regularizers_parameters

    def create_reg_wrappers(self, reg_type2name, background_topics, domain_topics, reg_cfg=None):
        """
        Creates one regularizer wrapper per active regularizer, ordered by regularizer type.\n
        :param str or list or dict reg_type2name: indicates which regularizers should be active; eg keys of the 'regularizers' section of the train.cfg
        - If type(reg_type2name) == str: reg_type2name is a file path to a cfg formated file that has a 'regularizers' section indicating the active regularization components.\n
        - If type(reg_type2name) == list: reg_type2name is a list of tuples with each 1st element being the regularizer type (eg 'smooth-phi', 'decorrelate-phi-domain') and each 2nd element being the regularizer unique name.\n
        - If type(reg_type2name) == dict: reg_type2name maps regularizer types to regularizer names
        :param list background_topics: a list of the 'background' topic names. Can be empty.
        :param list domain_topics: a list of the 'domain' topic names. Can be empty.
        :param str or dict reg_cfg: contains the values for initializing the regularizers' parameters (eg tau) with. If None (or otherwise falsy) then the default file is used
        - If type(reg_cfg) == str: reg_cfg is a file path to a cfg formated file that has as sections regularizer_types with their keys being initialization parameters\n
        - If type(reg_cfg) == dict: reg_cfg maps regularizer types to parameters dict.
        :return: the constructed regularizers; the objects returned by 'construct_reg_wrapper'
        :rtype: list
        """
        if not reg_cfg:
            reg_cfg = REGULARIZERS_CFG
        self._regs_data = RegularizersData(background_topics, domain_topics, reg_type2name, reg_cfg)
        # the topics are stored exactly as given by the caller
        self._back_t, self._domain_t = background_topics, domain_topics
        self._reg_settings = self._regs_data.regularizers_parameters
        logger.info("Active regs: {}".format(self._regs_data.regs_hash))
        logger.info("Regs default inits cfg file: {}".format(reg_cfg))
        logger.info("Reg settings: {}".format(self._regs_data.regularizers_parameters))
        active_settings = self._regs_data.regularizers_parameters
        # regularizer types are unique keys, so sorting the keys yields a deterministic creation order
        return [self.construct_reg_wrapper(reg_type, active_settings[reg_type]) for reg_type in sorted(active_settings)]

    def construct_reg_wrapper(self, reg_type, settings):
        """
        :param str reg_type: the regularizer's unique definition, based on reg_type, topics targeted, modality targeted
        :param dict settings: key, values pairs to initialize the regularizer parameters. Must contain 'name' key
        :return: the regularizer's wrapper object reference
        :rtype: ArtmRegularizerWrapper
        :raises RuntimeError: if 'reg_type' is not one of the supported regularizer definitions
        """
        if reg_type not in self._regularizer_type2constructor:
            raise RuntimeError("Requested to create '{}' regularizer, which is not supported".format(reg_type))
        if not self._back_t and reg_type.startswith('smooth'):
            logger.warning("Requested to create '{}' regularizer, which normally targets 'bakground' topicts, but there are "
                           "not distinct 'background' topics defined. The constructed regularizer will target all topics instead.".format(reg_type))
        # insert the 'long-type' key in a copy of the settings (the caller's dict is left untouched) to use it as
        # the truly unique 'type' of a regularizer
        settings_with_type = dict(settings)
        settings_with_type['long-type'] = reg_type
        return self._regularizer_type2constructor[reg_type](settings_with_type)
202
203
204
def _parse_active_regs(regs):
    """Normalize the description of the active regularizers into a plain dict.

    The input is dispatched on the exact name of its type (subclasses are not recognized):

    - 'dict' / 'OrderedDict': its items are copied into a new dict.
    - 'str': passed to cfg2model_settings and the 'regularizers' entry of the result is used.
    - 'list': every element is passed through _abbreviation to obtain a (type, name) pair.

    :param regs: the active regularizers; a dict, OrderedDict, str or list
    :return: maps regularizer types to regularizer names
    :rtype: dict
    :raises KeyError: if the type of 'regs' is none of the above; the error carries the type's name
    """
    kind = type(regs).__name__
    if kind in ('OrderedDict', 'dict'):
        pairs = regs.items()
    elif kind == 'str':
        pairs = cfg2model_settings(regs)['regularizers']
    elif kind == 'list':
        pairs = dict(_abbreviation(element) for element in regs)
    else:
        raise KeyError(kind)
    return dict(pairs)  # reg-def, reg-name pairs
210
211
def _abbreviation(reg_type):
    """Pair a regularizer type with a name, inferring the name as an abbreviation when only the type is given.

    For a string input the type is split on '-' and the name is built as follows:

    - one segment: the name is the type itself
    - two segments: the first two characters of each segment (eg 'smooth-phi' -> 'smph')
    - more segments: the first character of each segment (eg 'label-regularization-phi' -> 'lrp')

    Empty segments (eg caused by a trailing or doubled hyphen) contribute nothing to the name.

    :param str or tuple reg_type: a regularizer type, or a tuple already holding both the type and the name
    :return: the (type, name) pair; a tuple input is returned as is
    :rtype: tuple
    :raises ValueError: if the input is neither a string nor a tuple
    """
    if isinstance(reg_type, tuple):
        return reg_type
    if not isinstance(reg_type, str):
        raise ValueError("Either input a string representing a regularizer type (immplying that the name should be inferred as an abbreviation) or a tuple holding both the type and the name")
    segments = reg_type.split('-')
    if len(segments) == 1:
        return reg_type, reg_type
    prefix_length = 2 if len(segments) == 2 else 1
    # slicing (instead of indexing) tolerates empty segments
    return reg_type, ''.join(segment[:prefix_length] for segment in segments)
223
224
225
def cfg2regularizer_settings(cfg_file):
    """Read the regularizers' initialization settings from a cfg formatted file.

    Every section of the file becomes a key of the result. Settings with an empty value are skipped; the rest
    are encoded with the callable registered for their name in 'parameter_name2encoder'.

    :param str cfg_file: path to the cfg file
    :return: maps section names to an OrderedDict of setting names to encoded values
    :rtype: OrderedDict
    :raises KeyError: if a non-empty setting has a name with no entry in 'parameter_name2encoder'
    """
    parser = ConfigParser()
    parser.read(u'{}'.format(cfg_file))
    settings = OrderedDict()
    for section in parser.sections():
        section_settings = OrderedDict()
        for setting_name, value in parser.items(section):
            if not value:
                continue
            key = str(setting_name)
            section_settings[key] = parameter_name2encoder[key](value)
        settings[str(section)] = section_settings
    return settings
232
233
234
def _parse_reg_cfg(regs_config):
    """Resolve the regularizers' initialization settings from either a mapping or a cfg file path.

    The input is dispatched on the exact name of its type (subclasses are not recognized).

    :param regs_config: a dict/OrderedDict, which is returned as is, or a str, which is treated as the path of a
        cfg file and parsed with cfg2regularizer_settings
    :return: maps regularizer types to their initialization parameters
    :raises KeyError: if the type of 'regs_config' is none of the above; the error carries the type's name
    """
    kind = type(regs_config).__name__
    if kind == 'str':
        return cfg2regularizer_settings(regs_config)
    if kind in ('dict', 'OrderedDict'):
        return regs_config
    raise KeyError(kind)
239
240
241
def _create_reg_settings(self):
    """Compute the initialization parameters of every active regularizer.

    The settings found in 'self.reg_cfg' (resolved with _parse_reg_cfg) are looked up for each regularizer type
    of 'self.regs_hash'; a copy of them, extended with the regularizer's name under the 'name' key, is stored.

    :param self: object exposing the 'reg_cfg' and 'regs_hash' attributes (used as the default factory of
        RegularizersData.regularizers_parameters)
    :return: maps each active regularizer type to its initialization parameters (including 'name')
    :rtype: dict
    :raises KeyError: if an active regularizer type has no entry in the resolved settings
    """
    available_settings = _parse_reg_cfg(self.reg_cfg)
    init_params = {}
    for unique_type, name in self.regs_hash.items():
        try:
            init_params[unique_type] = dict(available_settings[unique_type], name=name)
        except KeyError:
            raise KeyError("'reg_cfg' {} resulting in settings {}, does not have key '{}'. Probably you forgot to add the corresponding entry in 'train.cfg' and/or 'regularizers.cfg'".format(self.reg_cfg, available_settings, unique_type))
    return init_params
255
256
257
@attr.s(cmp=True, hash=True, slots=False)
class RegularizersData(object):
    """Holds the targeted topics, the active regularizers and their initialization parameters.

    'regularizers_parameters' is not a constructor argument; it is computed on initialization, from the other
    attributes, by _create_reg_settings.
    """
    # names of the 'background' topics; the given iterable is converted to a list
    background_topics = attr.ib(converter=list, init=True, repr=True, cmp=True)
    # names of the 'domain' topics; the given iterable is converted to a list
    domain_topics = attr.ib(converter=list, init=True, repr=True, cmp=True)
    # the active regularizers, normalized by _parse_active_regs into a dict of regularizer types to names
    regs_hash = attr.ib(converter=_parse_active_regs, init=True)
    # the source of the initialization values, stored as given; it is resolved with _parse_reg_cfg when the
    # parameters below are computed
    reg_cfg = attr.ib(init=True)
    # declared last, since its default factory reads 'reg_cfg' and 'regs_hash' from the instance
    regularizers_parameters = attr.ib(
        init=False,
        default=attr.Factory(lambda instance: _create_reg_settings(instance), takes_self=True),
        repr=True,
        cmp=True,
        hash=True,
    )
264