patm.tuning.tuner   F

Complexity

Total Complexity 85

Size/Duplication

Total Lines 367
Duplicated Lines 0 %

Importance

Changes 0
Metric                              Value
eloc                                287
dl (duplicated lines)               0
loc (total lines)                   367
rs                                  2
c                                   0
b                                   0
f                                   0
wmc (weighted methods per class)    85

32 Methods

Rating   Name   Duplication   Size   Complexity  
A ParametersMixture.__contains__() 0 2 1
A Versioning.__call__() 0 5 2
A ParametersMixture.__len__() 0 2 1
A ParametersMixture.from_regularization_settings() 0 3 1
A Tuner.training_parameters() 0 4 1
C Tuner.tune() 0 50 10
C LabelingDefinition.from_training_parameters() 0 19 9
A RegularizationSpecifications.types() 0 3 1
A RegularizationSpecifications.extract() 0 10 4
A Tuner.extract() 0 8 3
A Tuner.__format_reg() 0 4 2
A RegularizationSpecifications.__getitem__() 0 2 1
A Tuner._model() 0 14 1
A LabelingDefinition.from_tuner() 0 5 2
A Tuner.constants() 0 3 1
A Tuner.explorables() 0 3 1
A Tuner.__getitem__() 0 6 3
A Versioning._iter_prepend() 0 7 3
A RegularizationSpecifications.__iter__() 0 2 1
B LabelingDefinition.select() 0 20 8
A Tuner.regularization_specs() 0 4 1
A Tuner._topics_str() 0 6 3
A ParametersMixture.extract() 0 2 1
A ParametersMixture.__getitem__() 0 2 1
A Tuner.parameter_names() 0 3 1
A LabelingDefinition._conv() 0 10 4
A ParametersMixture.__iter__() 0 2 1
A Tuner._val() 0 2 1
A Tuner.current_reg_specs() 0 6 1
A LabelingDefinition.__call__() 0 2 1
A Tuner._set_verbosity_level() 0 9 4
A Tuner.__attrs_post_init__() 0 2 1

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _check_extractor() 0 5 3
A _build() 0 4 2
A _conv() 0 4 2

How to fix: Complexity

Complex classes like patm.tuning.tuner often do a lot of different things. To break such a class down, we need to identify a cohesive component within the class. A common approach is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
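
As a rough illustration of Extract Class (all names below are hypothetical, not taken from this module), the idea is to move a cohesive group of fields and methods into their own class and let the original class delegate to an instance of it; the module's own Versioning class already follows this pattern:

from collections import Counter

# Before: label-versioning state and behaviour live inside the big class.
class MonolithicTuner(object):
    def __init__(self):
        self._seen_labels = Counter()          # versioning state

    def versioned_label(self, label):          # versioning behaviour
        self._seen_labels[label] += 1
        if self._seen_labels[label] > 1:
            return '{}_v{:02d}'.format(label, self._seen_labels[label] - 1)
        return label

# After: the cohesive component is extracted; the big class only delegates.
class LabelVersioner(object):
    def __init__(self):
        self._seen_labels = Counter()

    def __call__(self, label):
        self._seen_labels[label] += 1
        if self._seen_labels[label] > 1:
            return '{}_v{:02d}'.format(label, self._seen_labels[label] - 1)
        return label

class SlimTuner(object):
    def __init__(self):
        self.version = LabelVersioner()        # delegate versioning concerns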

import warnings
from collections import OrderedDict, Counter
import attr
from functools import reduce
import pprint
from types import FunctionType
from tqdm import tqdm
from .parameters import ParameterGrid
from ..modeling import TrainerFactory, Experiment
from ..definitions import DEFAULT_CLASS_NAME, IDEOLOGY_CLASS_NAME  # this is the name of the default modality; it is irrelevant to class labels or document classification


@attr.s
class Versioning(object):
    """Appends a version suffix (e.g. '_v01') to any label that has been seen before."""
    objects = attr.ib(init=True, default=[], converter=Counter, repr=True, cmp=True)
    _max_digits_version = attr.ib(init=True, default=2, converter=int, repr=True, cmp=True)
    _joiner = '_v'

    def __call__(self, data):
        self.objects[data] += 1
        if self.objects.get(data, 0) > 1:
            return '{}{}{}'.format(data, self._joiner, self._iter_prepend(self.objects[data]-1))
        return data

    def _iter_prepend(self, int_num):
        nb_digits = len(str(int_num))
        if nb_digits < self._max_digits_version:
            return '{}{}'.format((self._max_digits_version - nb_digits) * '0', int_num)
        if nb_digits == self._max_digits_version:
            return str(int_num)
        raise RuntimeError("More than 100 items are versioned. (but max_digit_length=2)")


@attr.s
class Tuner(object):
    dataset = attr.ib(init=True, repr=True)
    scores = attr.ib(init=True, default={
        'perplexity': 'per',
        'sparsity-phi-@dc': 'sppd',
        'sparsity-phi-@ic': 'sppi',
        'sparsity-theta': 'spt',
        # 'topic-kernel-0.25': 'tk25',
        'topic-kernel-0.60': 'tk60',
        'topic-kernel-0.80': 'tk80',
        'top-tokens-10': 'top10',
        'top-tokens-100': 'top100',
        'background-tokens-ratio-0.3': 'btr3',
        'background-tokens-ratio-0.2': 'btr2'
    })
    _training_parameters = attr.ib(init=True, default={}, converter=dict, repr=True)
    _reg_specs = attr.ib(init=True, default={}, converter=dict, repr=True)
    grid_searcher = attr.ib(init=True, default=None, repr=True)
    version = attr.ib(init=True, factory=Versioning, repr=True)

    _labeler = attr.ib(init=False, default=None)
    trainer = attr.ib(init=False, default=attr.Factory(lambda self: TrainerFactory().create_trainer(self.dataset, exploit_ideology_labels=True, force_new_batches=False), takes_self=True))
    experiment = attr.ib(init=False, default=attr.Factory(lambda self: Experiment(self.dataset), takes_self=True))

    def __attrs_post_init__(self):
        self.trainer.register(self.experiment)

    def __getitem__(self, item):
        if item == 'training':
            return self._training_parameters
        if item == 'regularization':
            return self._reg_specs
        raise KeyError

    @property
    def parameter_names(self):
        return self._training_parameters.parameter_names + self._reg_specs.parameter_names

    @property
    def constants(self):
        return self._training_parameters.steady + self._reg_specs.steady

    @property
    def explorables(self):
        return self._training_parameters.explorable + self._reg_specs.explorable

    @property
    def training_parameters(self):
        """The mixture of steady parameters and parameters to tune on"""
        return self._training_parameters

    @training_parameters.setter
    def training_parameters(self, training_parameters):
        """Provide a dict with the mixture of steady parameters and parameters to tune on"""
        self._training_parameters = ParametersMixture(training_parameters)

    @property
    def regularization_specs(self):
        """The specifications according to which regularization components should be activated, initialized and potentially evolved (see tau trajectory) during training"""
        return self._reg_specs

    @regularization_specs.setter
    def regularization_specs(self, regularization_specs):
        self._reg_specs = RegularizationSpecifications(regularization_specs)

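    # Illustrative note (not part of the original module): the two setters above
    # consume list-of-tuples input, as handled by ParametersMixture/_build and
    # RegularizationSpecifications below; a value given as a list defines a span
    # to grid-search over, while a scalar stays fixed. Hypothetical example:
    #
    #   tuner.training_parameters = [('nb-topics', [20, 40]), ('collection-passes', 100)]
    #   tuner.regularization_specs = [('sparse-phi', [('tau', [-1.0, -5.0])]),
    #                                 ('smooth-theta', [('tau', 1.0)])]
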
    @property
    def current_reg_specs(self):
        return {reg_type:
                    {param_name:
                         self._val('{}.{}'.format(reg_type, param_name))
                     for param_name in params_mixture.parameter_names} for reg_type, params_mixture in self._reg_specs}

    def _set_verbosity_level(self, input_verbose):
        try:
            self._vb = int(input_verbose)
            if self._vb < 0:
                self._vb = 0
            elif 5 < self._vb:
                self._vb = 5
        except ValueError:
            self._vb = 3

    def tune(self, *args, **kwargs):
        self._set_verbosity_level(kwargs.get('verbose', 3))

        if args:
            if len(args) > 0:
                self.training_parameters = args[0]
            if len(args) > 1:
                self.regularization_specs = args[1]

        self._labeler = LabelingDefinition.from_tuner(self, prefix=kwargs.get('prefix_label', ''), labeling_params=kwargs.get('labeling_params', False),
                                                      append_static=kwargs.get('append_static', False), append_explorable=kwargs.get('append_explorables', True),
                                                      preserve_order=kwargs.get('preserve_order', True),
                                                      parameter_set=kwargs.get('parameter_set', 'training'))

        self.grid_searcher = ParameterGrid(self._training_parameters.parameter_spans + [span for _, reg_params_mixture in self._reg_specs for span in reg_params_mixture.parameter_spans])

        if 1 < self._vb:
            print('Taking {} samples for grid-search'.format(len(self.grid_searcher)))
        # if kwargs.get('force_overwrite', True):
        print('Overwriting any existing results and phi matrices found')
        if self._vb:
            print('Tuning..')
            generator = tqdm(self.grid_searcher, total=len(self.grid_searcher), unit='model')
        else:
            generator = iter(self.grid_searcher)

        for i, self.parameter_vector in enumerate(generator):
            self._cur_label = self.version(self._labeler(self.parameter_vector))
            with warnings.catch_warnings(record=True) as w:
                # Cause all warnings to always be triggered.
                # warnings.simplefilter("always")
                tm, specs = self._model()
                # assert len(w) == 1
                # assert issubclass(w[-1].category, DeprecationWarning)
                # assert "The value of 'probability_mass_threshold' parameter should be set to 0.5 or higher" == str(w[-1].message)
                # Trigger a warning.
                # Verify some things
            tqdm.write(self._cur_label)
            tqdm.write("Background: [{}]".format(', '.join([x for x in tm.background_topics])))
            tqdm.write("Domain: [{}]".format(', '.join([x for x in tm.domain_topics])))

            if 4 < self._vb:
                tqdm.write(pprint.pformat({k: dict(v, **{k: v for k, v in {
                    'target topics': self._topics_str(tm.get_reg_obj(tm.get_reg_name(k)).topic_names, tm.domain_topics, tm.background_topics),
                    'mods': getattr(tm.get_reg_obj(tm.get_reg_name(k)), 'class_ids', None)}.items()}) for k, v in self.current_reg_specs.items()}))
            if 3 < self._vb:
                tqdm.write(pprint.pformat(tm.modalities_dictionary))
            self.experiment.init_empty_trackables(tm)
            self.trainer.train(tm, specs, cache_theta=kwargs.get('cache_theta', True))
            self.experiment.save_experiment(save_phi=True)

    def _topics_str(self, topics, domain, background):
        if topics == domain:
            return 'domain'
        if topics == background:
            return 'background'
        return '[{}]'.format(', '.join(topics))

    def _model(self):
        tm = self.trainer.model_factory.construct_model(self._cur_label, self._val('nb_topics'),
                                                        self._val('collection_passes'),
                                                        self._val('document_passes'),
                                                        self._val('background_topics_pct'),
                                                        {k: v for k, v in
                                                         {DEFAULT_CLASS_NAME: self._val('default_class_weight'),
                                                          IDEOLOGY_CLASS_NAME: self._val(
                                                              'ideology_class_weight')}.items() if v},
                                                        self.scores,
                                                        self._reg_specs.types,
                                                        reg_settings=self.current_reg_specs)  # a dictionary mapping reg_types to reg_specs
        tr_specs = self.trainer.model_factory.create_train_specs(self._val('collection_passes'))
        return tm, tr_specs

    def _val(self, parameter_name):
        return self.extract(self.parameter_vector, parameter_name.replace('_', '-'))

    def extract(self, parameters_vector, parameter_name):
        r = parameter_name.split('.')
        if len(r) == 1:
            return self._training_parameters.extract(parameters_vector, parameter_name)
        elif len(r) == 2:
            return self._reg_specs.extract(parameters_vector, r[0], r[1])
        else:
            raise ValueError("Either input a training parameter such as 'collection_passes', 'nb_topics', 'ideology_class_weight' or a regularizer's parameter in format such as 'sparse-phi.tau', 'label-regularization-phi-dom-def.tau'")

    def __format_reg(self, reg_specs, reg_type):
        if 'name' in reg_specs[reg_type]:
            return reg_type, reg_specs[reg_type].pop('name')
        return reg_type


############## PARAMETERS MIXTURE ##############


@attr.s
class RegularizationSpecifications(object):
    reg_specs = attr.ib(init=True, converter=lambda x: OrderedDict(
        [(reg_type, ParametersMixture([(param_name, value) for param_name, value in reg_specs])) for reg_type, reg_specs in x]), repr=True, cmp=True)
    parameter_spans = attr.ib(init=False, default=attr.Factory(
        lambda self: [span for reg_type, reg_specs in self.reg_specs.items() for span in reg_specs.parameter_spans],
        takes_self=True))
    parameter_names = attr.ib(init=False,
                              default=attr.Factory(lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.parameter_names], takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.steady],
        takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.explorable],
        takes_self=True))

    nb_combinations = attr.ib(init=False, default=attr.Factory(lambda self: reduce(lambda i, j: i * j, [v.nb_combinations for v in self.reg_specs.values()]), takes_self=True))

    def __getitem__(self, item):
        return self.reg_specs[item]

    def __iter__(self):
        return ((reg_type, reg_params_mixture) for reg_type, reg_params_mixture in self.reg_specs.items())

    def extract(self, parameter_vector, reg_name, reg_param):
        parameter_vector = list(parameter_vector[-len(self.parameter_spans):])
        if reg_name not in self.reg_specs:
            raise KeyError
        s = 0
        for k, v in self.reg_specs.items():
            if k == reg_name:
                break
            s += v.length
        return self.reg_specs[reg_name].extract(parameter_vector[s:], reg_param)

    @property
    def types(self):
        return list(self.reg_specs.keys())

def _conv(value):
    if type(value) != list:
        return [value]
    return value


def _build(tuple_list):
    if len(tuple_list) != len(set([x[0] for x in tuple_list])):
        raise ValueError("Input tuples should behave like a dict (unique elements as each 1st element)")
    return OrderedDict([(x[0], _conv(x[1])) for x in tuple_list])

@attr.s
class ParametersMixture(object):
    """An OrderedDict whose keys are parameter names and whose values are either a single object to initialize with or a list of objects modeling a span/grid of values for grid-search"""
    _data_hash = attr.ib(init=True, converter=_build)

    length = attr.ib(init=False, default=attr.Factory(lambda self: len(self._data_hash), takes_self=True))
    parameter_names = attr.ib(init=False, default=attr.Factory(lambda self: list(self._data_hash.keys()), takes_self=True))
    nb_combinations = attr.ib(init=False, default=attr.Factory(lambda self: reduce(lambda i,j: i*j, [len(v) for v in self._data_hash.values()]), takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(lambda self: [name for name, assumables_values in self._data_hash.items() if len(assumables_values) == 1], takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(lambda self: [name for name, assumables_values in self._data_hash.items() if len(assumables_values) > 1], takes_self=True))
    parameter_spans = attr.ib(init=False, default=attr.Factory(lambda self: [assumable_values for assumable_values in self._data_hash.values()], takes_self=True))

    def __contains__(self, item):
        return item in self._data_hash

    def __len__(self):
        return len(self._data_hash)

    def __getitem__(self, item):
        return self._data_hash[item]

    def __iter__(self):
        return ((k,v) for k,v in self._data_hash.items())

    def extract(self, parameter_vector, parameter):
        return parameter_vector[self.parameter_names.index(parameter)]

    @classmethod
    def from_regularization_settings(cls, reg_settings):
        return ParametersMixture([(k,v) for k, v in reg_settings])

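# Illustrative example (the values are hypothetical, not from the original source):
# _build wraps scalar values into single-element lists, so
#     ParametersMixture([('nb-topics', [20, 40]), ('document-passes', 1)])
# ends up with parameter_names == ['nb-topics', 'document-passes'],
# explorable == ['nb-topics'], steady == ['document-passes'] and nb_combinations == 2.
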
##############  LABELING ##############

def _check_extractor(self, attr_attribute_obj, input_value):
    if not hasattr(input_value, '__call__'):
        raise ValueError("A callable is required")
    if input_value.__code__.co_argcount != 2:
        raise ValueError("Callable should accept exactly 2 arguments. First is a parameter vector (list) and second a parameter name")


@attr.s
class LabelingDefinition(object):
    parameters = attr.ib(init=True, converter=list, repr=True, cmp=True)
    extractor = attr.ib(init=True, validator=_check_extractor)
    _prefix = attr.ib(init=True, default='') #, converter=str)

    def __call__(self, parameter_vector):
        return '_'.join(x for x in [self._prefix] + [self._conv(self.extractor(parameter_vector, param_name)) for param_name in self.parameters] if x)

    def _conv(self, v):
        try:
            v1 = float(v)
            if v1 >= 1e4:
                return "{:.2}".format(v1)
            if int(v1) == v1:
                return str(int(v1))
            return str(v)
        except ValueError:
            return str(v)

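    # Illustrative note (hypothetical values, not part of the original source):
    # with _prefix == 'exp' and parameters == ['nb-topics', 'collection-passes'],
    # a parameter vector from which 20 and 100 are extracted is labeled 'exp_20_100',
    # since _conv() renders whole-valued floats as plain integers.
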
    @classmethod
    def from_training_parameters(cls, training_parameters, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False):
        if labeling_params:
            if not type(labeling_params) == list:
                raise ValueError("If given, the labeling_params argument should be a list")
        else:
            if type(append_static) == list:
                labeling_params = append_static
            elif append_static:
                labeling_params = training_parameters.steady.copy()
            else:
                labeling_params = []
            if type(append_explorable) == list:
                labeling_params.extend(append_explorable)
            elif append_explorable:
                labeling_params.extend(training_parameters.explorable)
        if preserve_order:
            labeling_params = [x for x in training_parameters.parameter_names if x in labeling_params]
        return LabelingDefinition(labeling_params, lambda vector, param: training_parameters.extract(vector, param), prefix)

    @classmethod
    def from_tuner(cls, tuner, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False, parameter_set='training|regularization'):
        return LabelingDefinition(cls.select(tuner, labeling_params=labeling_params, append_static=append_static, append_explorable=append_explorable, preserve_order=preserve_order, parameter_set=parameter_set),
                                  lambda vector, param: tuner.extract(vector, param),
                                  prefix)

    @classmethod
    def select(cls, tuner, **kwargs):
        labeling_params = []
        if kwargs.get('labeling_params', None):
            if not type(kwargs.get('labeling_params', None)) == list:
                raise ValueError("If given, the labeling_params keyword-argument should be a list")
            labeling_params = kwargs['labeling_params']
        else:
            if type(kwargs.get('append_static', False)) == list:
                labeling_params.extend(kwargs['append_static'])
            elif kwargs.get('append_static', False):
                labeling_params.extend([x for el in kwargs.get('parameter_set', 'training|regularization').split('|') for x in tuner[el].steady])
            if type(kwargs.get('append_explorable', False)) == list:
                labeling_params.extend(kwargs['append_explorable'])
            elif kwargs.get('append_explorable', False):
                labeling_params.extend([x for el in kwargs.get('parameter_set', 'training|regularization').split('|') for x in tuner[el].explorable])
        if kwargs.get('preserve_order', False):
            # labeling_params = [x for x in [y for el in kwargs.get('parameter_set', 'training|regularization').split('|') for y in tuner[el].parameter_names] if x in labeling_params]
            labeling_params = [x for x in tuner.parameter_names if x in labeling_params]
        return labeling_params
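
# End-to-end usage sketch (added for this report; the dataset handle and all
# parameter values below are hypothetical, and the hyphenated names mirror the
# look-ups performed by Tuner._val/_model above):
#
#   tuner = Tuner('my-collection')
#   tuner.tune(
#       [('nb-topics', [20, 40]),
#        ('collection-passes', 100),
#        ('document-passes', 1),
#        ('background-topics-pct', 0.1),
#        ('default-class-weight', 1.0),
#        ('ideology-class-weight', 0)],
#       [('sparse-phi', [('tau', [-1.0, -5.0])])],
#       prefix_label='exp', verbose=3)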