1
|
|
|
import pprint |
2
|
|
|
|
3
|
|
|
import warnings |
4
|
|
|
from collections import Counter |
5
|
|
|
|
6
|
|
|
from .regularization.regularizers_factory import REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
class TopicModel(object):
    """
    An instance of this class encapsulates the behaviour of a topic_model:
    it bundles an ``artm.ARTM`` model with its regularizer wrappers and its
    score evaluators, and exposes convenience accessors over both.

    - LIMITATION:
        Does not support dynamic changing of the number of topics between different training cycles.
    """

    def __init__(self, label, artm_model, evaluators):
        """
        Creates a topic model object. The model is ready to add regularizers to it.\n
        :param str label: a unique identifier for the model; model name
        :param artm.ARTM artm_model: a reference to an artm model object
        :param dict evaluators: mapping from evaluator-definitions strings to patm.evaluation.base_evaluator.ArtmEvaluator objects
            ; i.e. {'perplexity': obj0, 'sparsity-phi-\@dc': obj1, 'sparsity-phi-\@ic': obj2, 'topic-kernel-0.6': obj3,
            'topic-kernel-0.8': obj4, 'top-tokens-10': obj5}
        """
        self.label = label
        self.artm_model = artm_model
        self._definition2evaluator = evaluators
        # ie {'sparsity-phi': 'spd_p', 'top-tokens-10': 'tt10'}
        self.definition2evaluator_name = {k: v.name for k, v in evaluators.items()}
        # inverse mapping; ie {'sp_p': 'sparsity-phi', 'tt10': 'top-tokens-10'}
        self._evaluator_name2definition = {v.name: k for k, v in evaluators.items()}
        self._reg_longtype2name = {}  # ie {'smooth-theta': 'smb_t', 'sparse-phi': 'spd_p'}
        self._reg_name2wrapper = {}

    def add_regularizer_wrapper(self, reg_wrapper):
        """Register a single regularizer wrapper and attach its artm regularizer to the model."""
        self._reg_name2wrapper[reg_wrapper.name] = reg_wrapper
        self._reg_longtype2name[reg_wrapper.long_type] = reg_wrapper.name
        self.artm_model.regularizers.add(reg_wrapper.artm_regularizer)

    def add_regularizer_wrappers(self, reg_wrappers):
        """Register each regularizer wrapper in *reg_wrappers*."""
        for wrapper in reg_wrappers:
            self.add_regularizer_wrapper(wrapper)

    @property
    def pformat_regularizers(self):
        """Pretty-printed view of every registered regularizer: its static
        parameters (minus 'topic_names'), the topics it targets and its modalities."""
        def _static_params(reg_name):
            # Work on a copy: the previous implementation deleted
            # 'topic_names' from the wrapper's own dict, permanently
            # corrupting its static parameters on first call.
            d = dict(self._reg_name2wrapper[reg_name].static_parameters)
            d.pop('topic_names', None)
            return d

        def _target_topics(reg_name):
            topic_names = self.get_reg_obj(reg_name).topic_names
            # An empty topic_names list on an artm regularizer means "apply to all topics".
            return 'all' if not topic_names else '[{}]'.format(', '.join(topic_names))

        return pprint.pformat({
            reg_long_type: dict(_static_params(reg_name),
                                **{'target topics': _target_topics(reg_name),
                                   'mods': getattr(self.get_reg_obj(reg_name), 'class_ids', None)})
            for reg_long_type, reg_name in self._reg_longtype2name.items()
        })

    @property
    def pformat_modalities(self):
        """Pretty-printed view of the model's modality (class_id) weights."""
        return pprint.pformat(self.modalities_dictionary)

    def initialize_regularizers(self, collection_passes, document_passes):
        """
        Call this method before starting of training to build some regularizers settings from run-time parameters.\n
        :param int collection_passes:
        :param int document_passes:
        """
        self._trajs = {}
        for reg_name, wrapper in self._reg_name2wrapper.items():
            self._trajs[reg_name] = wrapper.get_tau_trajectory(collection_passes)
            wrapper.set_alpha_iters_trajectory(document_passes)

    def get_reg_wrapper(self, reg_name):
        """Return the wrapper registered under *reg_name*, or None if unknown."""
        return self._reg_name2wrapper.get(reg_name, None)

    @property
    def regularizer_names(self):
        """Alphabetically sorted names of the regularizers attached to the artm model."""
        return sorted(self.artm_model.regularizers.data)

    @property
    def regularizer_wrappers(self):
        """Wrapper objects, ordered to match :attr:`regularizer_names`."""
        return [self._reg_name2wrapper[name] for name in self.regularizer_names]

    @property
    def regularizer_types(self):
        """Regularizer type strings, ordered to match :attr:`regularizer_names`.

        Returns a list (the previous implementation returned a one-shot ``map``
        iterator, which silently yielded nothing on a second traversal)."""
        return [self._reg_name2wrapper[name].type for name in self.regularizer_names]

    @property
    def regularizer_unique_types(self):
        """Regularizer long-type strings, ordered to match :attr:`regularizer_names`."""
        return [self._reg_name2wrapper[name].long_type for name in self.regularizer_names]

    @property
    def long_types_n_types(self):
        """List of (long_type, type) pairs, ordered to match :attr:`regularizer_names`."""
        return [(self._reg_name2wrapper[name].long_type, self._reg_name2wrapper[name].type)
                for name in self.regularizer_names]

    @property
    def evaluator_names(self):
        """Alphabetically sorted names of the scores attached to the artm model."""
        return sorted(self.artm_model.scores.data)

    @property
    def evaluator_definitions(self):
        """Evaluator definition strings, ordered to match :attr:`evaluator_names`."""
        return [self._evaluator_name2definition[eval_name] for eval_name in self.evaluator_names]

    @property
    def evaluators(self):
        """Evaluator objects, ordered to match :attr:`evaluator_definitions`."""
        return [self._definition2evaluator[ev_def] for ev_def in self.evaluator_definitions]

    @property
    def tau_trajectories(self):
        """(reg_name, trajectory) pairs for regularizers that actually have a tau trajectory.

        Requires :meth:`initialize_regularizers` to have been called first."""
        return [item for item in self._trajs.items() if item[1] is not None]

    @property
    def nb_topics(self):
        """Total number of topics in the underlying artm model."""
        return self.artm_model.num_topics

    @property
    def topic_names(self):
        """All topic names of the underlying artm model."""
        return self.artm_model.topic_names

    @property
    def domain_topics(self):
        """Returns the mostly agreed list of topic names found in all evaluators"""
        c = Counter()
        for evaluator in self._definition2evaluator.values():
            tn = getattr(evaluator, 'topic_names', None)
            if tn:
                # Encode each targeted subset as a single hashable key.
                c['+'.join(tn)] += 1
        # More than one distinct subset means the evaluators disagree on which
        # topics are "domain" topics (previously this only warned for > 2).
        if len(c) > 1:
            warnings.warn("There exist {} different subsets of all the topic names targeted by evaluators".format(len(c)))
        try:
            return c.most_common(1)[0][0].split('+')
        except IndexError:
            raise IndexError("Failed to compute domain topic names. Please enable at least one score with topics=domain_names argument. Scores: {}".
                             format(['{}: {}'.format(x.name, x.settings) for x in self.evaluators]))

    @property
    def background_topics(self):
        """Topic names of the artm model that are not domain topics."""
        return [topic_name for topic_name in self.artm_model.topic_names if topic_name not in self.domain_topics]

    @property
    def background_tokens(self):
        """Tokens reported by the last tracked 'background-tokens-ratio-*' score,
        or None when no such score is registered."""
        background_tokens_eval_name = ''
        for eval_def, eval_name in self.definition2evaluator_name.items():
            if eval_def.startswith('background-tokens-ratio-'):
                background_tokens_eval_name = eval_name
        if background_tokens_eval_name:
            res = self.artm_model.score_tracker[background_tokens_eval_name].tokens
            # res holds one token collection per tracked iteration; take the latest.
            return list(res[-1])
        return None  # explicit: previously fell off the end implicitly

    @property
    def modalities_dictionary(self):
        """Mapping of modality (class_id) to its weight in the artm model."""
        return self.artm_model.class_ids

    @property
    def document_passes(self):
        """Number of passes over each document per collection pass."""
        return self.artm_model.num_document_passes

    @document_passes.setter
    def document_passes(self, iterations):
        self.artm_model.num_document_passes = iterations

    def get_reg_obj(self, reg_name):
        """Return the raw artm regularizer object registered under *reg_name*."""
        return self.artm_model.regularizers[reg_name]

    def get_reg_name(self, reg_type):
        """Return the regularizer name registered for long-type *reg_type*.

        :raises KeyError: if no regularizer of that long-type was added."""
        return self._reg_longtype2name[reg_type]

    def get_evaluator(self, eval_name):
        """Return the evaluator object registered under score name *eval_name*."""
        return self._definition2evaluator[self._evaluator_name2definition[eval_name]]

    def set_parameter(self, reg_name, reg_param, value):
        """Set attribute *reg_param* of the regularizer named *reg_name* to *value*.

        :raises RegularizerNameNotFoundException: if no such regularizer exists
        :raises ParameterNameNotFoundException: if the regularizer has no such attribute
        """
        if reg_name in self.artm_model.regularizers.data:
            if hasattr(self.artm_model.regularizers[reg_name], reg_param):
                try:
                    setattr(self.artm_model.regularizers[reg_name], reg_param, value)
                except (AttributeError, TypeError) as e:
                    # Best-effort: artm may reject the assignment; report and continue.
                    print(e)
            else:
                raise ParameterNameNotFoundException("Regularizer '{}' with name '{}' has no attribute (parameter) '{}'".format(type(self.artm_model.regularizers[reg_name]).__name__, reg_name, reg_param))
        else:
            raise RegularizerNameNotFoundException("Did not find a regularizer in the artm_model with name '{}'".format(reg_name))

    def set_parameters(self, reg_name2param_settings):
        """Apply a nested {reg_name: {param: value}} mapping via :meth:`set_parameter`."""
        for reg_name, settings in reg_name2param_settings.items():
            for param, value in settings.items():
                self.set_parameter(reg_name, param, value)

    def get_regs_param_dict(self):
        """
        Returns a mapping between the model's regularizers unique string definition (ie shown in train.cfg: eg 'smooth-phi', 'sparse-theta',
        'label-regulariation-phi-cls', 'decorrelate-phi-dom-def') and their corresponding parameters that can be dynamically changed during training (eg 'tau', 'gamma').\n
        See patm.modeling.regularization.regularizers.REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH\n
        :return: the regularizer type (str) to parameters (list of strings) mapping (str => list)
        :rtype: dict
        """
        d = {}
        for unique_type, reg_type in self.long_types_n_types:
            d[unique_type] = {}
            cur_reg_obj = self.artm_model.regularizers[self._reg_longtype2name[unique_type]]
            for attribute_name in REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH[reg_type]:  # ie for _ in ('tau', 'gamma')
                d[unique_type][attribute_name] = getattr(cur_reg_obj, attribute_name)
        return d
247
|
|
|
class TrainSpecs(object):
    """Specification of one training cycle: the number of passes over the
    collection and a tau-coefficient trajectory per regularizer."""

    def __init__(self, collection_passes, reg_names, tau_trajectories):
        """
        :param int collection_passes: number of passes over the whole collection
        :param list reg_names: regularizer names, aligned index-wise with *tau_trajectories*
        :param list tau_trajectories: one tau-value sequence per regularizer; each
            must contain exactly *collection_passes* elements
        :raises ValueError: if names/trajectories counts differ or a trajectory has
            the wrong length (previously ``assert``, which is stripped under ``-O``)
        """
        self._col_iter = collection_passes
        if len(reg_names) != len(tau_trajectories):
            raise ValueError("Expected one tau trajectory per regularizer name; got {} names and {} trajectories".format(
                len(reg_names), len(tau_trajectories)))
        self._reg_name2tau_trajectory = dict(zip(reg_names, tau_trajectories))
        if not all(len(traj) == self._col_iter for traj in self._reg_name2tau_trajectory.values()):
            raise ValueError("Every tau trajectory must have exactly {} elements (one per collection pass)".format(self._col_iter))

    def tau_trajectory(self, reg_name):
        """Return the tau trajectory registered for *reg_name*, or None if unknown."""
        return self._reg_name2tau_trajectory.get(reg_name, None)

    @property
    def tau_trajectory_list(self):
        """Returns the list of (reg_name, tau trajectory) pairs, sorted alphabetically by regularizer name"""
        return sorted(self._reg_name2tau_trajectory.items(), key=lambda x: x[0])

    @property
    def collection_passes(self):
        """Number of passes over the collection this cycle performs."""
        return self._col_iter

    def to_taus_slice(self, iter_count):
        """Map each regularizer name to {'tau': value} taken at pass index *iter_count*."""
        return {reg_name: {'tau': trajectory[iter_count]}
                for reg_name, trajectory in self._reg_name2tau_trajectory.items()}
272
|
|
|
class RegularizerNameNotFoundException(Exception):
    """Raised when no regularizer with the requested name exists in the artm model.

    No ``__init__`` override is needed: ``Exception`` already stores and
    renders the message exactly as the forwarding constructor did.
    """
276
|
|
|
class ParameterNameNotFoundException(Exception):
    """Raised when a regularizer lacks the requested attribute (parameter).

    No ``__init__`` override is needed: ``Exception`` already stores and
    renders the message exactly as the forwarding constructor did.
    """