import warnings
from collections import OrderedDict, Counter
import attr
from functools import reduce
import pprint
from types import FunctionType
from tqdm import tqdm
from .parameters import ParameterGrid
from ..modeling import TrainerFactory, Experiment
from ..definitions import DEFAULT_CLASS_NAME, IDEOLOGY_CLASS_NAME  # name of the default modality; unrelated to class labels or document classification


@attr.s
class Versioning(object):
    objects = attr.ib(init=True, default=[], converter=Counter, repr=True, cmp=True)
    _max_digits_version = attr.ib(init=True, default=2, converter=int, repr=True, cmp=True)
    _joiner = '_v'

    def __call__(self, data):
        self.objects[data] += 1
        if self.objects.get(data, 0) > 1:
            return '{}{}{}'.format(data, self._joiner, self._iter_prepend(self.objects[data] - 1))
        return data

    def _iter_prepend(self, int_num):
        nb_digits = len(str(int_num))
        if nb_digits < self._max_digits_version:
            return '{}{}'.format((self._max_digits_version - nb_digits) * '0', int_num)
        if nb_digits == self._max_digits_version:
            return str(int_num)
        raise RuntimeError("More than {} items are versioned (max_digits_version={})".format(10 ** self._max_digits_version, self._max_digits_version))


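# Illustrative usage sketch (hypothetical label, not part of the original module): repeated
# labels receive zero-padded '_vNN' suffixes, so grid-search runs that would otherwise share
# a label stay distinct.
def _versioning_usage_example():
    versioning = Versioning()
    assert versioning('plsa_20topics') == 'plsa_20topics'       # first occurrence, unchanged
    assert versioning('plsa_20topics') == 'plsa_20topics_v01'   # second occurrence
    assert versioning('plsa_20topics') == 'plsa_20topics_v02'   # third occurrence

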
@attr.s
class Tuner(object):
    dataset = attr.ib(init=True, repr=True)
    scores = attr.ib(init=True, default={
        'perplexity': 'per',
        'sparsity-phi-@dc': 'sppd',
        'sparsity-phi-@ic': 'sppi',
        'sparsity-theta': 'spt',
        # 'topic-kernel-0.25': 'tk25',
        'topic-kernel-0.60': 'tk60',
        'topic-kernel-0.80': 'tk80',
        'top-tokens-10': 'top10',
        'top-tokens-100': 'top100',
        'background-tokens-ratio-0.3': 'btr3',
        'background-tokens-ratio-0.2': 'btr2'
    })
    _training_parameters = attr.ib(init=True, default={}, converter=dict, repr=True)
    _reg_specs = attr.ib(init=True, default={}, converter=dict, repr=True)
    grid_searcher = attr.ib(init=True, default=None, repr=True)
    version = attr.ib(init=True, factory=Versioning, repr=True)

    _labeler = attr.ib(init=False, default=None)
    trainer = attr.ib(init=False, default=attr.Factory(lambda self: TrainerFactory().create_trainer(self.dataset, exploit_ideology_labels=True, force_new_batches=False), takes_self=True))
    experiment = attr.ib(init=False, default=attr.Factory(lambda self: Experiment(self.dataset), takes_self=True))

    def __attrs_post_init__(self):
        self.trainer.register(self.experiment)

    def __getitem__(self, item):
        if item == 'training':
            return self._training_parameters
        if item == 'regularization':
            return self._reg_specs
        raise KeyError(item)

    @property
    def parameter_names(self):
        return self._training_parameters.parameter_names + self._reg_specs.parameter_names

    @property
    def constants(self):
        return self._training_parameters.steady + self._reg_specs.steady

    @property
    def explorables(self):
        return self._training_parameters.explorable + self._reg_specs.explorable

    @property
    def training_parameters(self):
        """The mixture of steady parameters and parameters to tune on"""
        return self._training_parameters

    @training_parameters.setter
    def training_parameters(self, training_parameters):
        """Provide a list of (name, value(s)) pairs mixing steady parameters and parameters to tune on"""
        self._training_parameters = ParametersMixture(training_parameters)

    @property
    def regularization_specs(self):
        """The specifications according to which regularization components should be activated, initialized and potentially evolved (see tau trajectory) during training"""
        return self._reg_specs

    @regularization_specs.setter
    def regularization_specs(self, regularization_specs):
        self._reg_specs = RegularizationSpecifications(regularization_specs)

    @property
    def current_reg_specs(self):
        return {reg_type:
                    {param_name: self._val('{}.{}'.format(reg_type, param_name))
                     for param_name in params_mixture.parameter_names}
                for reg_type, params_mixture in self._reg_specs}
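    # Shape sketch (hypothetical regularizer names/values): for the parameter vector currently
    # explored, current_reg_specs resolves to something like
    #   {'sparse-phi': {'tau': -0.1}, 'label-regularization-phi-dom-def': {'tau': 1.0}}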

    def _set_verbosity_level(self, input_verbose):
        try:
            self._vb = int(input_verbose)
            if self._vb < 0:
                self._vb = 0
            elif 5 < self._vb:
                self._vb = 5
        except ValueError:
            self._vb = 3

    def tune(self, *args, **kwargs):
        self._set_verbosity_level(kwargs.get('verbose', 3))

        if args:
            if len(args) > 0:
                self.training_parameters = args[0]
            if len(args) > 1:
                self.regularization_specs = args[1]

        self._labeler = LabelingDefinition.from_tuner(self, prefix=kwargs.get('prefix_label', ''), labeling_params=kwargs.get('labeling_params', False),
                                                      append_static=kwargs.get('append_static', False), append_explorable=kwargs.get('append_explorables', True),
                                                      preserve_order=kwargs.get('preserve_order', True),
                                                      parameter_set=kwargs.get('parameter_set', 'training'))

        self.grid_searcher = ParameterGrid(self._training_parameters.parameter_spans + [span for _, reg_params_mixture in self._reg_specs for span in reg_params_mixture.parameter_spans])

        if 1 < self._vb:
            print('Taking {} samples for grid-search'.format(len(self.grid_searcher)))
            # if kwargs.get('force_overwrite', True):
            print('Overwriting any existing results and phi matrices found')
        if self._vb:
            print('Tuning..')
            generator = tqdm(self.grid_searcher, total=len(self.grid_searcher), unit='model')
        else:
            generator = iter(self.grid_searcher)

        for i, self.parameter_vector in enumerate(generator):
            self._cur_label = self.version(self._labeler(self.parameter_vector))
            with warnings.catch_warnings(record=True) as w:
                # Cause all warnings to always be triggered.
                # warnings.simplefilter("always")
                tm, specs = self._model()
                # assert len(w) == 1
                # assert issubclass(w[-1].category, DeprecationWarning)
                # assert "The value of 'probability_mass_threshold' parameter should be set to 0.5 or higher" == str(w[-1].message)
            tqdm.write(self._cur_label)
            tqdm.write("Background: [{}]".format(', '.join([x for x in tm.background_topics])))
            tqdm.write("Domain: [{}]".format(', '.join([x for x in tm.domain_topics])))

            if 4 < self._vb:
                tqdm.write(pprint.pformat({k: dict(v, **{k: v for k, v in {
                    'target topics': self._topics_str(tm.get_reg_obj(tm.get_reg_name(k)).topic_names, tm.domain_topics, tm.background_topics),
                    'mods': getattr(tm.get_reg_obj(tm.get_reg_name(k)), 'class_ids', None)}.items()}) for k, v in self.current_reg_specs.items()}))
            if 3 < self._vb:
                tqdm.write(pprint.pformat(tm.modalities_dictionary))
            self.experiment.init_empty_trackables(tm)
            self.trainer.train(tm, specs, cache_theta=kwargs.get('cache_theta', True))
            self.experiment.save_experiment(save_phi=True)

    def _topics_str(self, topics, domain, background):
        if topics == domain:
            return 'domain'
        if topics == background:
            return 'background'
        return '[{}]'.format(', '.join(topics))

    def _model(self):
        tm = self.trainer.model_factory.construct_model(self._cur_label, self._val('nb_topics'),
                                                        self._val('collection_passes'),
                                                        self._val('document_passes'),
                                                        self._val('background_topics_pct'),
                                                        {k: v for k, v in
                                                         {DEFAULT_CLASS_NAME: self._val('default_class_weight'),
                                                          IDEOLOGY_CLASS_NAME: self._val('ideology_class_weight')}.items() if v},
                                                        self.scores,
                                                        self._reg_specs.types,
                                                        reg_settings=self.current_reg_specs)  # a dictionary mapping reg_types to reg_specs
        tr_specs = self.trainer.model_factory.create_train_specs(self._val('collection_passes'))
        return tm, tr_specs

    def _val(self, parameter_name):
        return self.extract(self.parameter_vector, parameter_name.replace('_', '-'))
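    # Parameter names are normalized from snake_case to the dash-separated form used in the
    # parameter mixtures, e.g. self._val('nb_topics') looks up 'nb-topics'; dotted names such
    # as 'sparse-phi.tau' are routed to the regularization specs by extract() below.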

    def extract(self, parameters_vector, parameter_name):
        r = parameter_name.split('.')
        if len(r) == 1:
            return self._training_parameters.extract(parameters_vector, parameter_name)
        elif len(r) == 2:
            return self._reg_specs.extract(parameters_vector, r[0], r[1])
        else:
            raise ValueError("Either input a training parameter such as 'collection_passes', 'nb_topics', 'ideology_class_weight' or a regularizer's parameter in a format such as 'sparse-phi.tau', 'label-regularization-phi-dom-def.tau'")

    def __format_reg(self, reg_specs, reg_type):
        if 'name' in reg_specs[reg_type]:
            return reg_type, reg_specs[reg_type].pop('name')
        return reg_type


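# Illustrative sketch (hypothetical dataset and values, not part of the original module):
# scalar entries stay fixed across the grid search, list entries are explored, and regularizer
# parameters are addressed as 'reg-type.param'. Note the dash-separated training-parameter
# names, which is the form Tuner._val() resolves.
def _tuner_usage_example(dataset):
    tuner = Tuner(dataset)
    tuner.tune(
        [('nb-topics', [20, 40]),
         ('collection-passes', 100),
         ('document-passes', [1, 5]),
         ('background-topics-pct', 0.1),
         ('default-class-weight', 1.0),
         ('ideology-class-weight', [1.0, 5.0])],
        [('sparse-phi', [('tau', [-0.1, -0.5])])],
        verbose=3)

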
############## PARAMETERS MIXTURE ##############


@attr.s
class RegularizationSpecifications(object):
    reg_specs = attr.ib(init=True, converter=lambda x: OrderedDict(
        [(reg_type, ParametersMixture([(param_name, value) for param_name, value in reg_specs])) for reg_type, reg_specs in x]), repr=True, cmp=True)
    parameter_spans = attr.ib(init=False, default=attr.Factory(
        lambda self: [span for reg_type, reg_specs in self.reg_specs.items() for span in reg_specs.parameter_spans],
        takes_self=True))
    parameter_names = attr.ib(init=False,
                              default=attr.Factory(lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.parameter_names], takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.steady],
        takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name) for reg_type, mixture in self.reg_specs.items() for param_name in mixture.explorable],
        takes_self=True))

    nb_combinations = attr.ib(init=False, default=attr.Factory(lambda self: reduce(lambda i, j: i * j, [v.nb_combinations for v in self.reg_specs.values()]), takes_self=True))

    def __getitem__(self, item):
        return self.reg_specs[item]

    def __iter__(self):
        return ((reg_type, reg_params_mixture) for reg_type, reg_params_mixture in self.reg_specs.items())

    def extract(self, parameter_vector, reg_name, reg_param):
        parameter_vector = list(parameter_vector[-len(self.parameter_spans):])
        if reg_name not in self.reg_specs:
            raise KeyError
        s = 0
        for k, v in self.reg_specs.items():
            if k == reg_name:
                break
            s += v.length
        return self.reg_specs[reg_name].extract(parameter_vector[s:], reg_param)

    @property
    def types(self):
        return list(self.reg_specs.keys())

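
# Illustrative sketch (hypothetical regularizer names and tau values, not part of the original
# module): nested (reg_type, [(param, value-or-span), ...]) tuples; list values become
# grid-search spans, scalars stay steady.
def _regularization_specs_example():
    reg_specs = RegularizationSpecifications([
        ('sparse-phi', [('tau', [-0.1, -0.5])]),
        ('label-regularization-phi-dom-def', [('tau', 1.0)])])
    assert reg_specs.parameter_names == ['sparse-phi.tau', 'label-regularization-phi-dom-def.tau']
    assert reg_specs.explorable == ['sparse-phi.tau']
    assert reg_specs.steady == ['label-regularization-phi-dom-def.tau']
    assert reg_specs.nb_combinations == 2

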
def _conv(value):
    if type(value) != list:
        return [value]
    return value


def _build(tuple_list):
    if len(tuple_list) != len(set([x[0] for x in tuple_list])):
        raise ValueError("Input tuples should behave like a dict (each 1st element must be unique)")
    return OrderedDict([(x[0], _conv(x[1])) for x in tuple_list])


@attr.s
class ParametersMixture(object):
    """An OrderedDict whose keys are parameter names and whose values are either a single object to initialize with or a list of objects modeling a span/grid of values for grid-search"""
    _data_hash = attr.ib(init=True, converter=_build)

    length = attr.ib(init=False, default=attr.Factory(lambda self: len(self._data_hash), takes_self=True))
    parameter_names = attr.ib(init=False, default=attr.Factory(lambda self: list(self._data_hash.keys()), takes_self=True))
    nb_combinations = attr.ib(init=False, default=attr.Factory(lambda self: reduce(lambda i, j: i * j, [len(v) for v in self._data_hash.values()]), takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(lambda self: [name for name, assumable_values in self._data_hash.items() if len(assumable_values) == 1], takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(lambda self: [name for name, assumable_values in self._data_hash.items() if len(assumable_values) > 1], takes_self=True))
    parameter_spans = attr.ib(init=False, default=attr.Factory(lambda self: [assumable_values for assumable_values in self._data_hash.values()], takes_self=True))

    def __contains__(self, item):
        return item in self._data_hash

    def __len__(self):
        return len(self._data_hash)

    def __getitem__(self, item):
        return self._data_hash[item]

    def __iter__(self):
        return ((k, v) for k, v in self._data_hash.items())

    def extract(self, parameter_vector, parameter):
        return parameter_vector[self.parameter_names.index(parameter)]

    @classmethod
    def from_regularization_settings(cls, reg_settings):
        return ParametersMixture([(k, v) for k, v in reg_settings])

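
# Illustrative sketch (hypothetical parameter names/values, not part of the original module):
# scalar values are 'steady', list values are 'explorable' grid-search spans.
def _parameters_mixture_example():
    mixture = ParametersMixture([('collection-passes', 100),
                                 ('nb-topics', [20, 40]),
                                 ('document-passes', [1, 5, 10])])
    assert mixture.steady == ['collection-passes']
    assert mixture.explorable == ['nb-topics', 'document-passes']
    assert mixture.nb_combinations == 6
    assert mixture.extract([100, 20, 1], 'nb-topics') == 20

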
############## LABELING ##############

def _check_extractor(self, attr_attribute_obj, input_value):
    if not hasattr(input_value, '__call__'):
        raise ValueError("A callable is required")
    if input_value.__code__.co_argcount != 2:
        raise ValueError("The callable should accept exactly 2 arguments: a parameter vector (list) and a parameter name")


@attr.s
class LabelingDefinition(object):
    parameters = attr.ib(init=True, converter=list, repr=True, cmp=True)
    extractor = attr.ib(init=True, validator=_check_extractor)
    _prefix = attr.ib(init=True, default='')  # , converter=str)

    def __call__(self, parameter_vector):
        return '_'.join(x for x in [self._prefix] + [self._conv(self.extractor(parameter_vector, param_name)) for param_name in self.parameters] if x)

    def _conv(self, v):
        try:
            v1 = float(v)
            if v1 >= 1e4:
                return "{:.2}".format(v1)
            if int(v1) == v1:
                return str(int(v1))
            return str(v)
        except ValueError:
            return str(v)

    @classmethod
    def from_training_parameters(cls, training_parameters, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False):
        if labeling_params:
            if not type(labeling_params) == list:
                raise ValueError("If given, the labeling_params argument should be a list")
        else:
            if type(append_static) == list:
                labeling_params = append_static
            elif append_static:
                labeling_params = training_parameters.steady.copy()
            else:
                labeling_params = []
            if type(append_explorable) == list:
                labeling_params.extend(append_explorable)
            elif append_explorable:
                labeling_params.extend(training_parameters.explorable)
        if preserve_order:
            labeling_params = [x for x in training_parameters.parameter_names if x in labeling_params]
        return LabelingDefinition(labeling_params, lambda vector, param: training_parameters.extract(vector, param), prefix)

    @classmethod
    def from_tuner(cls, tuner, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False, parameter_set='training|regularization'):
        return LabelingDefinition(cls.select(tuner, labeling_params=labeling_params, append_static=append_static, append_explorable=append_explorable, preserve_order=preserve_order, parameter_set=parameter_set),
                                  lambda vector, param: tuner.extract(vector, param),
                                  prefix)

    @classmethod
    def select(cls, tuner, **kwargs):
        labeling_params = []
        if kwargs.get('labeling_params', None):
            if not type(kwargs.get('labeling_params', None)) == list:
                raise ValueError("If given, the labeling_params keyword-argument should be a list")
            labeling_params = kwargs['labeling_params']
        else:
            if type(kwargs.get('append_static', False)) == list:
                labeling_params.extend(kwargs['append_static'])
            elif kwargs.get('append_static', False):
                labeling_params.extend([x for el in kwargs.get('parameter_set', 'training|regularization').split('|') for x in tuner[el].steady])
            if type(kwargs.get('append_explorable', False)) == list:
                labeling_params.extend(kwargs['append_explorable'])
            elif kwargs.get('append_explorable', False):
                labeling_params.extend([x for el in kwargs.get('parameter_set', 'training|regularization').split('|') for x in tuner[el].explorable])
        if kwargs.get('preserve_order', False):
            # labeling_params = [x for x in [y for el in kwargs.get('parameter_set', 'training|regularization').split('|') for y in tuner[el].parameter_names] if x in labeling_params]
            labeling_params = [x for x in tuner.parameter_names if x in labeling_params]
        return labeling_params
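

# Illustrative sketch (hypothetical parameter names/values, not part of the original module):
# a labeler selects the parameters of interest and joins their values into a model label
# such as 'plsa_20_100'.
def _labeling_definition_example():
    mixture = ParametersMixture([('nb-topics', [20, 40]), ('collection-passes', 100)])
    labeler = LabelingDefinition.from_training_parameters(mixture, prefix='plsa',
                                                          append_static=True,
                                                          append_explorable=True,
                                                          preserve_order=True)
    assert labeler([20, 100]) == 'plsa_20_100'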