from .model_factory import ModelFactory
from .persistence import ResultsWL, ModelWL
from .regularization.regularizers_factory import REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH as DYN_COEFS


class Experiment:
    """
    This class encapsulates experiment-related activities, such as tracking various quantities during model training,
    and is able to persist these tracked quantities as experimental results.\n
    Experimental results include:
    - number of topics: the number of topics inferred
    - document passes: the inner iterations over each document; the phi matrix gets updated 'document_passes' x 'nb_docs' times during one pass through the whole collection
    - root dir: the base directory of the collection on which experiments were conducted
    - model label: the unique identifier of the model used for the experiments
    - trackable "evaluation" metrics
    - regularization parameters
    - _evaluator_name2definition
    """

    MAX_DECIMALS = 2
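
    # A minimal usage sketch (a hedged illustration: building the TopicModel and the
    # training loop live outside this class; 'topic_model' and 'span' are placeholders):
    #   experiment = Experiment('/data/my-collection')
    #   experiment.init_empty_trackables(topic_model)
    #   # ... fit in chunks, calling experiment.update(topic_model, span) after each chunk ...
    #   experiment.save_experiment(save_phi=True)
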
    def __init__(self, dataset_dir):
        """
        Encapsulates experimentation by doing topic modeling on a dataset/'collection' under the given root directory. A 'collection' is a document collection processed into BoW format and possibly split into 'train' and 'test' splits.\n
        :param str dataset_dir: the full path to a dataset/collection specific directory
        """
        self._dir = dataset_dir
        # self.cooc_dict = cooc_dict
        self._loaded_dictionary = None  # artm.Dictionary object. Data population happens upon artm.ARTM object creation in model_factory; dictionary.load(bin_dict_path) is called there
        self._topic_model = None
        self.collection_passes = []
        self.trackables = None
        # self.train_results_handler = ResultsWL.from_experiment(self)
        # self.phi_matrix_handler = ModelWL.from_experiment(self)
        self.train_results_handler = ResultsWL(self)
        self.phi_matrix_handler = ModelWL(self)
        self.failed_top_tokens_coherence = {}
        self._total_passes = 0
        self.regularizers_dynamic_parameters = None
        self._last_tokens = {}

    def init_empty_trackables(self, model):
        self._topic_model = model
        self.trackables = {}
        self.failed_top_tokens_coherence = {}
        self._total_passes = 0
        for evaluator_definition in model.evaluator_definitions:
            if self._strip_parameters(evaluator_definition) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                self.trackables[Experiment._assert_max_decimals(evaluator_definition)] = []
            elif evaluator_definition.startswith('topic-kernel-'):
                self.trackables[Experiment._assert_max_decimals(evaluator_definition)] = [
                    [], [], [], [],
                    {t_name: {'coherence': [], 'contrast': [], 'purity': [], 'size': []} for t_name in model.domain_topics}]
            elif evaluator_definition.startswith('top-tokens-'):
                self.trackables[evaluator_definition] = [[], {t_name: [] for t_name in model.domain_topics}]
                self.failed_top_tokens_coherence[evaluator_definition] = {t_name: [] for t_name in model.domain_topics}
        self.collection_passes = []
        if not all(x in DYN_COEFS for _, x in model.long_types_n_types):
            raise KeyError("One of the regularizer types [{}] (not the unique types) is missing from REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH in the regularizers_factory module".format(', '.join("'{}'".format(x) for _, x in model.long_types_n_types)))
        self.regularizers_dynamic_parameters = {unique_type: {attr: [] for attr in DYN_COEFS[reg_type]} for unique_type, reg_type in model.long_types_n_types}

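    # For reference, the per-definition layout of self.trackables built above
    # ('topic-kernel-0.60' / 'top-tokens-10' are example definitions, 't00' an example topic name):
    #   'perplexity'        -> [value_per_pass, ...]
    #   'topic-kernel-0.60' -> [avg_coherences, avg_contrasts, avg_purities, avg_sizes,
    #                           {'t00': {'coherence': [...], 'contrast': [...], 'purity': [...], 'size': [...]}, ...}]
    #   'top-tokens-10'     -> [avg_coherences, {'t00': [...], ...}]
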
    @property
    def dataset_iterations(self):
        return sum(self.collection_passes)

    @property
    def model_factory(self):
        return ModelFactory(self.dictionary)

    @property
    def topic_model(self):
        """
        :rtype: patm.modeling.topic_model.TopicModel
        """
        return self._topic_model

    @staticmethod
    def _strip_parameters(score_definition):
        """Strips the trailing parameter tokens (floats and tokens starting with '@') off a score definition string."""
        tokens = []
        for el in score_definition.split('-'):
            try:
                _ = float(el)
            except ValueError:
                if el[0] != '@':
                    tokens.append(el)
        return '-'.join(tokens)

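    # For instance (derived from the branch logic above):
    #   _strip_parameters('topic-kernel-0.60') -> 'topic-kernel'
    #   _strip_parameters('top-tokens-100')    -> 'top-tokens'
    #   _strip_parameters('perplexity')        -> 'perplexity'
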
    @classmethod
    def _assert_max_decimals(cls, definition):
        """Normalizes a trailing float to exactly MAX_DECIMALS decimal digits:\n
        - 'topic-kernel-0.6' -> 'topic-kernel-0.60'\n
        - 'topic-kernel-0.871' -> 'topic-kernel-0.87'\n
        - 'background-tokens-ratio-0.3' -> 'background-tokens-ratio-0.30'
        """
        s = definition.split('-')
        sl = s[-1]
        try:
            _ = float(sl)
            if len(sl) < 4:
                sl = '{}{}'.format(sl, '0' * (2 + cls.MAX_DECIMALS - len(sl)))
            else:
                sl = sl[:4]
            return '-'.join(s[:-1] + [sl])
        except ValueError:
            return definition

    @property
    def dictionary(self):
        """The artm.Dictionary that is passed to the Perplexity artm scorer."""
        return self._loaded_dictionary

    @dictionary.setter
    def dictionary(self, artm_dictionary):
        self._loaded_dictionary = artm_dictionary

    # TODO refactor this; remove dubious exceptions
    def update(self, topic_model, span):
        self.collection_passes.append(span)  # iterations performed on the train set for the current 'steady' chunk

        for unique_reg_type, reg_settings in list(topic_model.get_regs_param_dict().items()):
            for param_name, param_value in list(reg_settings.items()):
                self.regularizers_dynamic_parameters[unique_reg_type][param_name].extend([param_value] * span)
        for evaluator_name, evaluator_definition in zip(topic_model.evaluator_names, topic_model.evaluator_definitions):
            reportable_to_results = topic_model.get_evaluator(evaluator_name).evaluate(topic_model.artm_model)
            definition_with_max_decimals = Experiment._assert_max_decimals(evaluator_definition)

            if self._strip_parameters(evaluator_definition) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                self.trackables[definition_with_max_decimals].extend(reportable_to_results['value'][-span:])
            elif evaluator_definition.startswith('topic-kernel-'):
                self.trackables[definition_with_max_decimals][0].extend(reportable_to_results['average_coherence'][-span:])
                self.trackables[definition_with_max_decimals][1].extend(reportable_to_results['average_contrast'][-span:])
                self.trackables[definition_with_max_decimals][2].extend(reportable_to_results['average_purity'][-span:])
                self.trackables[definition_with_max_decimals][3].extend(reportable_to_results['average_size'][-span:])
                for topic_name, topic_metrics in list(self.trackables[definition_with_max_decimals][4].items()):
                    topic_metrics['coherence'].extend([x[topic_name] for x in reportable_to_results['coherence'][-span:]])
                    topic_metrics['contrast'].extend([x[topic_name] for x in reportable_to_results['contrast'][-span:]])
                    topic_metrics['purity'].extend([x[topic_name] for x in reportable_to_results['purity'][-span:]])
                    topic_metrics['size'].extend([x[topic_name] for x in reportable_to_results['size'][-span:]])
            elif evaluator_definition.startswith('top-tokens-'):
                self.trackables[evaluator_definition][0].extend(reportable_to_results['average_coherence'][-span:])

                assert 'coherence' in reportable_to_results
                for topic_name, topic_metrics in list(self.trackables[evaluator_definition][1].items()):
                    try:
                        topic_metrics.extend([x[topic_name] for x in reportable_to_results['coherence'][-span:]])
                    except KeyError:
                        # record (or extend) the window of train iterations for which per-topic coherence was missing
                        if len(self.failed_top_tokens_coherence[evaluator_definition][topic_name]) == 0:
                            self.failed_top_tokens_coherence[evaluator_definition][topic_name].append((self._total_passes, span))
                        else:
                            if span + self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1][0] == self._total_passes:
                                prev_tuple = self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1]
                                self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1] = (self._total_passes, prev_tuple[1] + span)
        self._total_passes += span
        self.final_tokens = {
            'topic-kernel': {eval_def: self._get_final_tokens(eval_def) for eval_def in
                             self.topic_model.evaluator_definitions if eval_def.startswith('topic-kernel-')},
            'top-tokens': {eval_def: self._get_final_tokens(eval_def) for eval_def in
                           self.topic_model.evaluator_definitions if eval_def.startswith('top-tokens-')},
            'background-tokens': self.topic_model.background_tokens
        }

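    # final_tokens layout as assembled in update() (a sketch; 'topic-kernel-0.60' and
    # 'top-tokens-10' are example definitions, and each inner value is assumed to be the
    # last tracked {topic_name: [tokens]} snapshot from the artm score tracker):
    #   {'topic-kernel': {'topic-kernel-0.60': {topic_name: [tokens], ...}, ...},
    #    'top-tokens': {'top-tokens-10': {topic_name: [tokens], ...}, ...},
    #    'background-tokens': [tokens]}
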
    @property
    def current_root_dir(self):
        return self._dir

    def save_experiment(self, save_phi=True):
        """Dumps the accumulated experimental results (a dictionary) under the current model's label. The file is saved in the directory specified by the latest train specifications (TrainSpecs)."""
        if not self.collection_passes:
            raise DidNotReceiveTrainSignalException("Model probably hasn't been fitted, since len(self.collection_passes) = {}".format(len(self.collection_passes)))
        # Ideally the number of observations recorded per tracked metric equals the number of "collection passes" (training iterations over the document dataset);
        # artm.ARTM.scores satisfy this.
        # print 'Saving model \'{}\', train set iterations: {}'.format(self.topic_model.label, self.collection_passes)
        # assert all(map(lambda x: len(x) == sum(self.collection_passes), [values_list for eval2scoresdict in self.trackables.values() for values_list in eval2scoresdict.values()]))
        # The above assertion would fail once metrics from outside the artm library get tracked, because those can only capture the last state of the metric trajectory due to fitting by "chunks".
        self.train_results_handler.save(self._topic_model.label)
        if save_phi:
            self.phi_matrix_handler.save(self._topic_model.label)

    def load_experiment(self, model_label):
        """
        Given a unique model label, restores the state of the experiment from disk. Loads all tracked values of the experimental results
        and the state of the TopicModel inferred so far, namely the phi p_wt matrix.
        In detail, it loads:\n
        - doc collection fit iteration steady_chunks\n
        - eval metrics/measures tracked per iteration\n
        - regularization parameters\n
        - document passes\n
        :param str model_label: a unique identifier of a topic model
        :return: the topic model inferred in the restored experiment
        :rtype: patm.modeling.topic_model.TopicModel
        """
        results = self.train_results_handler.load(model_label)
        self._topic_model = self.phi_matrix_handler.load(model_label, results)
        # self._topic_model.scores = {v.name: topic_model.artm_model.score_tracker[v.name] for v in topic_model._definition2evaluator.values()}
        self.failed_top_tokens_coherence = {}
        self.collection_passes = [item for sublist in results.tracked.collection_passes for item in sublist]
        self._total_passes = sum(sum(_) for _ in results.tracked.collection_passes)
        try:
            self.regularizers_dynamic_parameters = dict(results.tracked.regularization_dynamic_parameters)
        except KeyError as e:
            print(e)
            raise RuntimeError("Tracked: {}, dynamic reg params: {}".format(results.tracked, results.tracked.regularization_dynamic_parameters))
        self.trackables = self._get_trackables(results)
        self.final_tokens = {
            'topic-kernel': {kernel_def: {t_name: list(tokens) for t_name, tokens in topics_tokens} for kernel_def, topics_tokens in list(results.final.kernel_hash.items())},
            'top-tokens': {top_tokens_def: {t_name: list(tokens) for t_name, tokens in topics_tokens} for top_tokens_def, topics_tokens in list(results.final.top_hash.items())},
            'background-tokens': list(results.final.background_tokens),
        }
        return self._topic_model

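    # Round-trip sketch (hedged; assumes a model previously saved under the label 'my_model'):
    #   experiment.save_experiment(save_phi=True)
    #   ...
    #   restored = Experiment('/data/my-collection')
    #   topic_model = restored.load_experiment('my_model')
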
    def _get_final_tokens(self, evaluation_definition):
        return self.topic_model.artm_model.score_tracker[self.topic_model.definition2evaluator_name[evaluation_definition]].tokens[-1]

    def _get_trackables(self, results):
        """
        :param results.experimental_results.ExperimentalResults results:
        :return: the tracked metric values keyed by evaluator definition
        :rtype: dict
        """
        trackables = {}
        for k in (_.replace('_', '-') for _ in dir(results.tracked) if _ not in ('tau_trajectories', 'regularization_dynamic_parameters')):
            if self._strip_parameters(k) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                trackables[Experiment._assert_max_decimals(k)] = results.tracked[k].all
            elif k.startswith('topic-kernel-'):
                trackables[Experiment._assert_max_decimals(k)] = [
                    results.tracked[k].average.coherence.all,
                    results.tracked[k].average.contrast.all,
                    results.tracked[k].average.purity.all,
                    results.tracked[k].average.size.all,
                    {t_name: {'coherence': getattr(results.tracked[k], t_name).coherence.all,
                              'contrast': getattr(results.tracked[k], t_name).contrast.all,
                              'purity': getattr(results.tracked[k], t_name).purity.all,
                              'size': getattr(results.tracked[k], t_name).size.all} for t_name in results.scalars.domain_topics}
                ]
            elif k.startswith('top-tokens-'):
                trackables[k] = [results.tracked[k].average_coherence.all,
                                 {t_name: getattr(results.tracked[k], t_name).all for t_name in results.scalars.domain_topics}]
        return trackables


class DegenerationChecker(object):
    """Detects, for each reference key, the intervals of iterations over which a sequence of dicts carries no information for that key (the key is missing or maps to an empty value)."""

    def __init__(self, reference_keys):
        self._keys = sorted(reference_keys)

    def _initialize(self):
        self._iter = 0
        self._degen_info = {key: [] for key in self._keys}
        self._building = {k: False for k in self._keys}  # whether an open degeneration interval is being extended per key
        self._li = {k: 0 for k in self._keys}  # left (inclusive) index of the open interval
        self._ri = {k: 0 for k in self._keys}  # right (exclusive) index of the open interval

    def __str__(self):
        return '[{}]'.format(', '.join(self._keys))

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, str(self))

    @property
    def keys(self):
        return self._keys

    @keys.setter
    def keys(self, keys):
        self._keys = sorted(keys)

    def get_degeneration_info(self, dict_list):
        self._initialize()
        return dict([(x, self.get_degenerated_tuples(x, self._get_struct(dict_list))) for x in self._keys])

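    # Worked example of the interval semantics (start inclusive, end exclusive):
    #   checker = DegenerationChecker(['a', 'b'])
    #   checker.get_degeneration_info([{'a': [1], 'b': []}, {'a': []}, {'a': [2], 'b': [3]}])
    #   # -> {'a': [(1, 2)], 'b': [(0, 2)]}
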
    # def get_degen_info_0(self, dict_list):
    #     self.build(dict_list)
    #     return self._degen_info
    #
    # def build(self, dict_list):
    #     self._prepare_storing()
    #     for k in self._keys:
    #         self._build_degen_info(k, self._get_struct(dict_list))
    #         self._add_final(k, self._degen_info[k])

    def get_degenerated_tuples(self, key, struct):
        """
        :param str key:
        :param list struct: output of self._get_struct(dict_list)
        :return: a list of tuples indicating the starting and finishing train iterations during which the information has degenerated (gone missing)
        """
        degenerated = [t for t in (self._build_tuple(key, x[1], x[0]) for x in enumerate(struct)) if t]
        self._add_final(key, degenerated)
        return degenerated

    def _add_final(self, key, info):
        # close the interval that is still open once the whole struct has been scanned
        if self._building[key]:
            info.append((self._li[key], self._ri[key]))

    def _build_tuple(self, key, struct_element, iter_count):
        """
        :param str key:
        :param list struct_element:
        :param int iter_count:
        :return: the closed (left, right) interval if one was just completed, else None
        """
        r = None
        if key in struct_element:  # the key has lost its information (has degenerated) for this iteration/collection_pass
            if self._building[key]:
                self._ri[key] += 1
            else:
                self._li[key] = iter_count
                self._ri[key] = iter_count + 1
                self._building[key] = True
        else:
            if self._building[key]:
                r = (self._li[key], self._ri[key])
                self._building[key] = False
        return r

    def _build_degen_info(self, key, struct):
        for i, el in enumerate(struct):
            if key in el:  # the key has lost its information (has degenerated) for iteration/collection_pass i
                if self._building[key]:
                    self._ri[key] += 1
                else:
                    self._li[key] = i
                    self._ri[key] = i + 1
                    self._building[key] = True
            else:
                if self._building[key]:
                    self._degen_info[key].append((self._li[key], self._ri[key]))
                    self._building[key] = False

    def _get_struct(self, dict_list):
        return [self._get_degen_keys(self._keys, x) for x in dict_list]

    def _get_degen_keys(self, key_list, a_dict):
        return [k for k in key_list if self._has_degenerated(k, a_dict)]

    def _has_degenerated(self, key, a_dict):
        return key not in a_dict or not a_dict[key]


class EvaluationOutputLoadingException(Exception): pass
class DidNotReceiveTrainSignalException(Exception): pass