1
|
|
|
import sys |
2
|
|
|
import json |
3
|
|
|
from math import ceil |
4
|
|
|
import attr |
5
|
|
|
from functools import reduce |
6
|
|
|
|
7
|
|
|
from collections import OrderedDict |
8
|
|
|
import re |
9
|
|
|
|
10
|
|
|
import logging |
11
|
|
|
logger = logging.getLogger(__name__) |
12
|
|
|
|
13
|
|
|
|
14
|
|
|
|
15
|
|
|
@attr.s(cmp=True, repr=True, str=True)
class ExperimentalResults(object):
    """Aggregate of everything recorded about one topic-model training run.

    Composed of the steady scalars, the per-iteration tracked values, the final
    token state and the regularizer/score metadata; round-trips through JSON via
    RoundTripEncoder / RoundTripDecoder.
    """
    # SteadyTrackedItems: values constant throughout training (nb topics, passes, ...)
    scalars = attr.ib(init=True, cmp=True, repr=True)
    # ValueTracker: per-iteration evolution of every tracked score
    tracked = attr.ib(init=True, cmp=True, repr=True)
    # FinalStateEntities: kernel/top/background tokens at the end of training
    final = attr.ib(init=True, cmp=True, repr=True)
    # Sorted list of regularizer labels used during training
    regularizers = attr.ib(init=True, converter=lambda x: sorted(x), cmp=True, repr=True)
    reg_defs = attr.ib(init=True, cmp=True, repr=True)
    score_defs = attr.ib(init=True, cmp=True, repr=True)

    def __str__(self):
        return 'Scalars:\n{}\nTracked:\n{}\nFinal:\n{}\nRegularizers: {}'.format(self.scalars, self.tracked, self.final, ', '.join(self.regularizers))

    def __eq__(self, other):
        # NOTE(review): equality is deliberately disabled (always True); the intended
        # comparison is preserved in the trailing comment. Confirm before re-enabling,
        # since callers may currently rely on the permissive behaviour.
        return True # [self.scalars, self.regularizers, self.reg_defs, self.score_defs] == [other.scalars, other.regularizers, other.reg_defs, other.score_defs]

    @property
    def tracked_kernels(self):
        """TrackedKernel per tracked kernel threshold (eg threshold '0.60' -> attribute 'kernel60')."""
        return [getattr(self.tracked, 'kernel'+str(x)[2:]) for x in self.tracked.kernel_thresholds]

    @property
    def tracked_top_tokens(self):
        """TrackedTopTokens per tracked top-tokens cardinality (eg 10 -> attribute 'top10')."""
        return [getattr(self.tracked, 'top{}'.format(x)) for x in self.tracked.top_tokens_cardinalities]

    @property
    def phi_sparsities(self):
        """Tracked sparsity-phi entity per modality initial (eg 'd' -> 'sparsity_phi_d').

        Bug fix: the format string previously had no placeholder
        ('sparsity_phi_'.format(x) ignores x), so every modality resolved to the
        same nonexistent attribute name.
        """
        return [getattr(self.tracked, 'sparsity_phi_{}'.format(x)) for x in self.tracked.modalities_initials]

    def to_json(self, human_redable=True):
        """Serialize to a JSON string; pretty-printed with indent 2 when human_redable."""
        if human_redable:
            indent = 2
        else:
            indent = None
        return json.dumps(self, cls=RoundTripEncoder, indent=indent)

    def save_as_json(self, file_path, human_redable=True, debug=False):
        """Serialize to JSON and write to file_path.

        :raises FileNotFoundError: if the parent directory of file_path does not exist
        """
        if human_redable:
            indent = 2
        else:
            indent = None
        import os
        if not os.path.isdir(os.path.dirname(file_path)):
            raise FileNotFoundError("Directory in path '{}' has not been created.".format(os.path.dirname(file_path)))
        with open(file_path, 'w') as fp:
            json.dump(self, fp, cls=RoundTripEncoder, indent=indent)

    @classmethod
    def _legacy(cls, key, data, factory):
        # Return data[key] when present; otherwise synthesize it via factory(key, data)
        # so objects saved in the old format can still be loaded.
        if key not in data:
            logger.info("'{}' not in experimental results dict with keys [{}]. Perhaps object is of the old format".format(key, ', '.join(data.keys())))
            return factory(key, data)
        return data[key]

    @classmethod
    def create_from_json_file(cls, file_path):
        """Deserialize an ExperimentalResults from a JSON file written by save_as_json."""
        with open(file_path, 'r') as fp:
            res = json.load(fp, cls=RoundTripDecoder)
        return cls.from_dict(res)

    @classmethod
    def from_dict(cls, data):
        """Build from a deserialized dict, tolerating legacy formats that lack 'reg_defs'/'score_defs'."""
        return ExperimentalResults(SteadyTrackedItems.from_dict(data),
                                   ValueTracker.from_dict(data),
                                   FinalStateEntities.from_dict(data),
                                   data['regularizers'],
                                   cls._legacy('reg_defs', data, lambda x, y: {'type-'+str(i): v[:v.index('|')] for i, v in enumerate(y['regularizers'])}),
                                   cls._legacy('score_defs', data,
                                               lambda x, y: {k: 'name-'+str(i) for i, k in
                                                             enumerate(_ for _ in sorted(y['tracked'].keys()) if _ not in ['tau-trajectories', 'regularization-dynamic-parameters', 'collection-passes'])}))

    @classmethod
    def create_from_experiment(cls, experiment):
        """
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        return ExperimentalResults(SteadyTrackedItems.from_experiment(experiment),
                                   ValueTracker.from_experiment(experiment),
                                   FinalStateEntities.from_experiment(experiment),
                                   [x.label for x in experiment.topic_model.regularizer_wrappers],
                                   {k: v for k, v in zip(experiment.topic_model.regularizer_unique_types, experiment.topic_model.regularizer_names)},
                                   experiment.topic_model.definition2evaluator_name)
97
|
|
|
|
98
|
|
|
|
99
|
|
|
############### PARSER ################# |
100
|
|
|
@attr.s(slots=False)
class StringToDictParser(object):
    """Parses a string (if '-' found in string they are converted to '_'; any '@' is removed) trying various regexes at runtime and returns the one capturing the most information (this is simply measured by the number of entities captured)."""
    # Primitive patterns from which the design regexes below are assembled.  All
    # patterns are compiled with re.X (verbose mode), so literal spaces are ignored.
    # The converter derives the composite 'score-type' pattern from the primitives.
    # NOTE(review): 'modality-argum1ent' looks like a typo'd leftover of
    # 'modality-argument' and appears unused -- confirm before removing.
    regs = attr.ib(init=False, default={'score-word': r'[a-zA-Z]+',
                                        'score-sep': r'(?:-|_)',
                                        'numerical-argument': r'(?: \d+\. )? \d+',
                                        'modality-argument': r'[a-zA-Z]',
                                        'modality-argum1ent': r'@?[a-zA-Z]c?'}, converter=lambda x: dict(x, **{'score-type': r'{score-word}(?:{score-sep}{score-word})*'.format(**x)}), cmp=True, repr=True)
    # Maps abbreviated score names to their canonical, fully spelled-out form.
    normalizer = attr.ib(init=False, default={
        'kernel': 'topic-kernel',
        'top': 'top-tokens',
        'btr': 'background-tokens-ratio'
    }, cmp=False, repr=True)

    def __call__(self, *args, **kwargs):
        """Parse args[0] against every regex 'design' and return the OrderedDict from the design that captured the most entities.

        Keyword args: 'design' (list of format-string regexes), 'debug' (use
        re.search instead of re.match), 'encode' (normalize captured values via
        self.encode before returning).
        """
        self.string = args[0]
        self.design = kwargs.get('design', [
            r'({score-type}) {score-sep} ({numerical-argument})',
            r'({score-type}) {score-sep} @?({modality-argument})c?$',
            r'({score-type}) ({numerical-argument})',
            r'({score-type})'])
        if kwargs.get('debug', False):
            dd = []
            for d in self.design:
                dd.append(self.search_debug(d, args[0]))
            if kwargs.get('encode', False):
                return self.encode(max(dd, key=lambda x: len(x)))
            return max(dd, key=lambda x: len(x))
        if kwargs.get('encode', False):
            return self.encode(max([self.search_n_dict(r, args[0]) for r in self.design],
                                   key=lambda x: len(x)))
        return max([self.search_n_dict(r, args[0]) for r in self.design], key=lambda x: len(x))

    def search_n_dict(self, design_line, string):
        # Anchored-at-start match of 'string' against the fully substituted design
        # line; pairs each captured group with its entity name, dropping empty
        # captures.  The getattr(..., 'groups', fallback) trick yields empty
        # strings when there is no match at all.
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), list(getattr(re.compile(design_line.format(**self.regs), re.X).match(string), 'groups', lambda: len(self._entities(design_line)) * [''])())) if v])

    def search_debug(self, design_line, string):
        # Same as search_n_dict, but uses re.search (match anywhere) and spells
        # out the intermediate steps for easier inspection.
        reg = re.compile(design_line.format(**self.regs), re.X)
        res = reg.search(string)
        if res:
            ls = res.groups()
        else:
            ls = len(self._entities(design_line))*['']
        return OrderedDict([(k, v) for k , v in zip(self._entities(design_line), list(ls)) if v])

    def encode(self, ord_d):
        """Normalize a parsed OrderedDict in place: canonical score-type name, '@<initial>c' modality form, two-decimal threshold strings. Returns the same dict."""
        ord_d['score-type'] = self.normalizer.get(ord_d['score-type'].replace('_', '-'), ord_d['score-type'].replace('_', '-'))
        if 'modality-argument' in ord_d:
            ord_d['modality-argument'] = '@{}c'.format(ord_d['modality-argument'].lower())
        if 'numerical-argument' in ord_d:
            ord_d['numerical-argument'] = self._norm_numerical(ord_d['score-type'], ord_d['numerical-argument'])
        return ord_d

    def _norm_numerical(self, score_type, value):
        # Thresholds for these two score types are fractions: render them with
        # exactly two decimals (eg '6' -> '0.60', '0.8' -> '0.80').
        if score_type in ['topic-kernel', 'background-tokens-ratio']:
            a_float = float(value)
            if int(a_float) == a_float:
                value = '0.' + str(int(a_float))
            return '{:.2f}'.format(float(value))
        return value

    def _entities(self, design_line):
        # Extract the entity names referenced as capturing groups in a design line,
        # eg '({score-type}) ({numerical-argument})' -> ['score-type', 'numerical-argument'].
        # Works because each substituted pattern happens to match its own key text.
        return self._post_process(re.findall(r''.join([r'\({', r'(?:{score-type}|{numerical-argument}|{modality-argument})'.format(**self.regs), r'}\)']), design_line))

    def _post_process(self, entities):
        # Strip the surrounding '({' and '})' from each captured entity reference.
        return [x[2:-2] for x in entities]
201
|
|
|
|
202
|
|
|
|
203
|
|
|
################### VALUE TRACKER ################### |
204
|
|
|
|
205
|
|
|
@attr.s(cmp=True, repr=True, str=True)
class ValueTracker(object):
    """Exposes every per-iteration tracked quantity of a training run.

    The raw input maps tracked-score definition strings to their data; each
    definition is parsed with StringToDictParser into a normalized key and the
    data is wrapped in the appropriate Tracked* container.
    """
    # Raw {tracked-definition-string: data} mapping as recorded during training.
    _tracked = attr.ib(init=True, cmp=True, repr=True)
    parser = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __attrs_post_init__(self):
        # 'scores' holds evaluator/score results; '_rest' holds everything else
        # (tau trajectories, dynamic regularization parameters, collection passes, ...).
        self.scores = {}
        self._rest = {}

        for tracked_definition, v in list(self._tracked.items()):  # assumes maximum depth is 2

            d = self.parser(tracked_definition, encode=True, debug=False)
            key = '-'.join(d.values())
            try:
                tracked_type = d['score-type']
            except KeyError:
                raise KeyError("String '{}' was not parsed successfully. d = {}".format(tracked_definition, d))
            if tracked_type == 'topic-kernel':
                # NOTE(review): stored under the raw definition, while every other
                # branch uses the normalized 'key' -- confirm the asymmetry is wanted.
                self.scores[tracked_definition] = TrackedKernel(*v)
            elif tracked_type == 'top-tokens':
                self.scores[key] = TrackedTopTokens(*v)
            elif tracked_type in ['perplexity', 'sparsity-theta', 'background-tokens-ratio']:
                self.scores[key] = TrackedEntity(tracked_type, v)
            elif tracked_type in ['sparsity-phi']:
                self.scores[key] = TrackedEntity(tracked_type, v)
            elif tracked_type == 'tau-trajectories':
                self._rest[key] = TrackedTrajectories(v)
            elif tracked_type == 'regularization-dynamic-parameters':
                self._rest[key] = TrackedEvolvingRegParams(v)
            else:
                self._rest[key] = TrackedEntity(tracked_type, v)

    @classmethod
    def from_dict(cls, data):
        """Build a ValueTracker from the 'tracked' section of a deserialized results dict."""
        tracked = data['tracked']
        d = {'perplexity': tracked['perplexity'],
             'sparsity-theta': tracked['sparsity-theta'],
             'collection-passes': tracked['collection-passes'],
             'tau-trajectories': tracked['tau-trajectories'],
             }
        if 'regularization-dynamic-parameters' in tracked:
            d['regularization-dynamic-parameters'] = tracked['regularization-dynamic-parameters']
        else:
            logger.info("Did not find 'regularization-dynamic-parameters' in tracked values. Probably, reading from legacy formatted object")
        # Merge the scalar-style entries with the restructured kernel/top-tokens
        # data; reduce(dict(x, **y)) left-folds the dicts into one.
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            d,
            {key: v for key, v in list(tracked.items()) if key.startswith('background-tokens-ratio')},
            {'topic-kernel-' + key: [[value['avg_coh'], value['avg_con'], value['avg_pur'], value['size']],
                                     {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['topic-kernel'].items())},
            {'top-tokens-' + key: [value['avg_coh'], {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['top-tokens'].items())},
            {key: v for key, v in list(tracked.items()) if key.startswith('sparsity-phi-@')}
        ]))

    @classmethod
    def from_experiment(cls, experiment):
        """Build a ValueTracker directly from a live Experiment's trackables."""
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            {'perplexity': experiment.trackables['perplexity'],
             'sparsity-theta': experiment.trackables['sparsity-theta'],
             'collection-passes': experiment.collection_passes,
             'tau-trajectories': {matrix_name: experiment.regularizers_dynamic_parameters.get('sparse-' + matrix_name, {}).get('tau', []) for matrix_name in ['theta', 'phi']},
             'regularization-dynamic-parameters': experiment.regularizers_dynamic_parameters},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('background-tokens-ratio')},
            {kernel_definition: [[value[0], value[1], value[2], value[3]], value[4]] for kernel_definition, value in
             list(experiment.trackables.items()) if kernel_definition.startswith('topic-kernel')},
            {top_tokens_definition: [value[0], value[1]] for top_tokens_definition, value in
             list(experiment.trackables.items()) if top_tokens_definition.startswith('top-tokens-')},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('sparsity-phi-@')},
        ]))

    def __dir__(self):
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def top_tokens_cardinalities(self):
        """Sorted list of ints, eg [10, 100], one per tracked top-tokens score."""
        return sorted(int(_.split('-')[-1]) for _ in list(self.scores.keys()) if _.startswith('top-tokens'))

    @property
    def kernel_thresholds(self):
        """
        :return: list of strings eg ['0.60', '0.80', '0.25']
        """
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('topic-kernel'))

    @property
    def modalities_initials(self):
        # Keys look like 'sparsity-phi-@dc'; [1] picks the initial after '@'.
        return sorted(_.split('-')[-1][1] for _ in list(self.scores.keys()) if _.startswith('sparsity-phi'))

    @property
    def tracked_entity_names(self):
        """Every tracked key, from both 'scores' and '_rest', sorted."""
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def background_tokens_thresholds(self):
        """Sorted list of threshold strings of the tracked background-tokens-ratio scores."""
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('background-tokens-ratio'))

    @property
    def tau_trajectory_matrices_names(self):
        """Names of the matrices ('phi'/'theta') that have a tracked tau trajectory."""
        try:
            return self._rest['tau-trajectories'].matrices_names
        except KeyError:
            raise KeyError("Key 'tau-trajectories' was not found in scores [{}]. ALL: [{}]".format(
                ', '.join(sorted(self.scores.keys())), ', '.join(self.tracked_entity_names)))

    def __getitem__(self, item):
        # Parse 'item' through the same parser used at construction so loose
        # spellings (eg 'top_10') resolve to the normalized key.
        d = self.parser(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.scores:
            return self.scores[key]
        elif key in self._rest:
            return self._rest[key]
        raise KeyError(
            "Requested item '{}', converted to '{}' but it was not found either in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format(
                item, key, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys()))))

    def __getattr__(self, item):
        # Same lookup as __getitem__, exposed as attribute access.
        d = self.parser(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.scores:
            return self.scores[key]
        elif key in self._rest:
            return self._rest[key]
        raise AttributeError(
            "Requested item '{}', converted to '{}' after parsed as {}. It was not found either in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format(
                item, key, d, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys()))))
340
|
|
|
|
341
|
|
|
############################################################## |
342
|
|
|
|
343
|
|
|
|
344
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEvolvingRegParams(object):
    """Holds regularizers_parameters data which is a dictionary: keys should be the unique regularizer definitions
    (as in train.cfg; eg 'label-regularization-phi-dom-cls'). Keys should map to dictionaries, which map strings to lists.
    These dictionaries keys should correspond to one of the supported dynamic parameters (eg 'tau', 'gamma') and each
    list should have length equal to the number of collection iterations and hold the evolution/trajectory of the parameter values
    """
    # {regularizer-definition: {parameter-name: TrackedEntity of per-iteration values}}
    _evolved = attr.ib(init=True, converter=lambda regularizers_params: {reg: {param: TrackedEntity(param, values_list) for param, values_list in reg_params.items()} for
                                                                         reg, reg_params in regularizers_params.items()}, cmp=True)
    regularizers_definitions = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._evolved.keys()), takes_self=True), cmp=False, repr=True)

    def __iter__(self):
        # Iterate as (regularizer_definition, {param: TrackedEntity}) pairs.
        return ((k, v) for k, v in self._evolved.items())

    def __getattr__(self, item):
        # Fall back to the evolved-parameters mapping for unknown attributes.
        # NOTE(review): raises KeyError (not AttributeError) for unknown names,
        # which breaks hasattr/getattr-with-default -- confirm intended.
        return self._evolved[item]
360
|
|
|
|
361
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedKernel(object):
    """data = [avg_coherence_list, avg_contrast_list, avg_purity_list, sizes_list, topic_name2elements_hash]"""
    # Averages over all topics, per iteration, wrapped in a KernelSubGroup.
    average = attr.ib(init=True, cmp=True, converter=lambda x: KernelSubGroup(x[0], x[1], x[2], x[3]))
    # Per-topic KernelSubGroup keyed by topic name.
    # NOTE(review): the topic name 'key' is passed as the 4th positional argument
    # ('size'), leaving 'name' at its default '' -- looks suspicious; confirm.
    _topics_data = attr.ib(converter=lambda topic_name2elements_hash: {key: KernelSubGroup(val['coherence'], val['contrast'], val['purity'], key) for key, val in list(topic_name2elements_hash.items())}, init=True, cmp=True)
    topics = attr.ib(default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True))

    def __getattr__(self, item):
        # Expose each topic's KernelSubGroup as an attribute named after the topic.
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))
372
|
|
|
|
373
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTopTokens(object):
    """Per-iteration coherence of a top-tokens score: the all-topics average plus a per-topic breakdown."""
    average_coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('average_coherence', x), cmp=True, repr=True)
    # {topic_name: TrackedEntity of per-iteration coherence values}
    _topics_data = attr.ib(init=True, converter=lambda topic_name2coherence: {key: TrackedEntity(key, val) for key, val in list(topic_name2coherence.items())}, cmp=True, repr=True)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True), cmp=False, repr=True)

    def __getattr__(self, item):
        # Expose each topic's TrackedEntity as an attribute named after the topic.
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))
383
|
|
|
|
384
|
|
|
|
385
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class KernelSubGroup(object):
    """Contains averages over topics for metrics computed on lexical kernels defined with a specific threshold [coherence, contrast, purity] and the average kernel size."""
    coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('coherence', x), cmp=True, repr=True)
    contrast = attr.ib(init=True, converter=lambda x: TrackedEntity('contrast', x), cmp=True, repr=True)
    purity = attr.ib(init=True, converter=lambda x: TrackedEntity('purity', x), cmp=True, repr=True)
    size = attr.ib(init=True, converter=lambda x: TrackedEntity('size', x), cmp=True, repr=True)
    # Topic name when the subgroup belongs to one topic; '' for the averaged group.
    name = attr.ib(init=True, default='', cmp=True, repr=True)
393
|
|
|
|
394
|
|
|
|
395
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTrajectories(object):
    """Tau-coefficient trajectories per matrix (eg 'phi', 'theta'), each wrapped as a TrackedEntity over the training iterations."""
    # {matrix_name: TrackedEntity of per-iteration tau values}
    _trajs = attr.ib(init=True, converter=lambda matrix_name2elements_hash: {k: TrackedEntity(k, v) for k, v in list(matrix_name2elements_hash.items())}, cmp=True, repr=True)
    trajectories = attr.ib(init=False, default=attr.Factory(lambda self: sorted(list(self._trajs.items()), key=lambda x: x[0]), takes_self=True), cmp=False, repr=False)
    matrices_names = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._trajs.keys()), takes_self=True), cmp=False, repr=True)

    @property
    def phi(self):
        """Tau trajectory for the phi matrix; raises AttributeError if not tracked."""
        if 'phi' in self._trajs:
            return self._trajs['phi']
        raise AttributeError

    @property
    def theta(self):
        """Tau trajectory for the theta matrix; raises AttributeError if not tracked."""
        if 'theta' in self._trajs:
            return self._trajs['theta']
        raise AttributeError

    def __str__(self):
        return str(self.matrices_names)
415
|
|
|
|
416
|
|
|
|
417
|
|
|
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEntity(object):
    """A named sequence of values recorded once per training iteration."""
    name = attr.ib(init=True, converter=str, cmp=True, repr=True)
    # NOTE: 'all' shadows the builtin, but it is part of the public interface.
    all = attr.ib(init=True, converter=list, cmp=True, repr=True)  # list of elements (values tracked)

    @property
    def last(self):
        """The most recently tracked value; raises IndexError when nothing was tracked."""
        return self.all[-1]

    def __len__(self):
        return len(self.all)

    def __getitem__(self, item):
        """Index/slice into the tracked values; the string 'all' returns the whole list."""
        if item == 'all':
            return self.all
        try:
            return self.all[item]
        # Bug fix: 'self.all' is always a list (converter=list), so indexing raises
        # IndexError -- the previous 'except KeyError' clause could never fire.
        except IndexError:
            raise IndexError("self: {} type of self._elements: {}".format(self, type(self.all).__name__))
436
|
|
|
|
437
|
|
|
######### CONSTANTS during training and between subsequent fit calls |
438
|
|
|
|
439
|
|
|
def _check_topics(self, attribute, value):
    """Attrs validator: reject a topics list containing duplicate entries."""
    unique_topics = set(value)
    if len(unique_topics) != len(value):
        raise RuntimeError("Detected duplicates in input topics: [{}]".format(', '.join(str(x) for x in value)))
442
|
|
|
|
443
|
|
|
def _non_overlapping(self, attribute, value):
    """Attrs validator: reject domain topics that also appear in self.background_topics."""
    for background_topic in self.background_topics:
        if background_topic in value:
            raise RuntimeError("Detected overlapping domain topics [{}] with background topics [{}]".format(', '.join(str(x) for x in value), ', '.join(str(x) for x in self.background_topics)))
446
|
|
|
|
447
|
|
|
def background_n_domain_topics_soundness(instance, attribute, value):
    """Attrs validator: nb_topics must equal len(background_topics) + len(domain_topics)."""
    expected = len(instance.background_topics) + len(value)
    if expected != instance.nb_topics:
        raise ValueError("nb_topics should be equal to len(background_topics) + len(domain_topics). Instead, {} != {} + {}".format(instance.nb_topics, len(instance.background_topics), len(value)))
450
|
|
|
|
451
|
|
|
|
452
|
|
|
@attr.s(repr=True, cmp=True, str=True, slots=True)
class SteadyTrackedItems(object):
    """Values that stay constant during training. Supports only one tokens list for a specific btr threshold."""
    dir = attr.ib(init=True, converter=str, cmp=True, repr=True)  # TODO: rename to 'dataset_dir' (also shadows the 'dir' builtin)
    model_label = attr.ib(init=True, converter=str, cmp=True, repr=True)
    dataset_iterations = attr.ib(init=True, converter=int, cmp=True, repr=True)
    nb_topics = attr.ib(init=True, converter=int, cmp=True, repr=True)
    document_passes = attr.ib(init=True, converter=int, cmp=True, repr=True)
    background_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=_check_topics)
    domain_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=[_check_topics, _non_overlapping, background_n_domain_topics_soundness])
    background_tokens_threshold = attr.ib(init=True, converter=float, cmp=True, repr=True)
    modalities = attr.ib(init=True, converter=dict, cmp=True, repr=True)

    # Plain class attribute (not an attrs field): one parser shared by all instances.
    parser = StringToDictParser()
    def __dir__(self):
        return ['dir', 'model_label', 'dataset_iterations', 'nb_topics', 'document_passes', 'background_topics', 'domain_topics', 'background_tokens_threshold', 'modalities']

    @classmethod
    def from_dict(cls, data):
        """Build from a deserialized results dict; uses the maximum threshold found amongst the 'background-tokens-ratio-*' tracked keys."""
        steady = data['scalars']
        background_tokens_threshold = 0  # no distinction between background and "domain" tokens
        for eval_def in data['tracked']:
            if eval_def.startswith('background-tokens-ratio'):
                background_tokens_threshold = max(background_tokens_threshold, float(eval_def.split('-')[-1]))
        return SteadyTrackedItems(steady['dir'], steady['label'], data['scalars']['dataset_iterations'], steady['nb_topics'],
                                  steady['document_passes'], steady['background_topics'], steady['domain_topics'], background_tokens_threshold, steady['modalities'])

    @classmethod
    def from_experiment(cls, experiment):
        """
        Uses the maximum threshold found amongst the defined 'background-tokens-ratio' scores.\n
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        ds = [d for d in experiment.topic_model.evaluator_definitions if d.startswith('background-tokens-ratio-')]
        m = 0
        for d in (_.split('-')[-1] for _ in ds):
            if str(d).startswith('0.'):
                m = max(m, float(d))
            else:
                # Threshold stored without the leading '0.' (eg '60' meaning 0.60).
                m = max(m, float('0.' + str(d)))

        return SteadyTrackedItems(experiment.current_root_dir, experiment.topic_model.label, experiment.dataset_iterations,
                                  experiment.topic_model.nb_topics, experiment.topic_model.document_passes, experiment.topic_model.background_topics,
                                  experiment.topic_model.domain_topics, m, experiment.topic_model.modalities_dictionary)
499
|
|
|
|
500
|
|
|
|
501
|
|
|
@attr.s(repr=True, str=True, cmp=True, slots=True)
class TokensList(object):
    """Thin wrapper around a list of tokens supporting len/indexing/iteration/membership."""
    tokens = attr.ib(init=True, converter=list, repr=True, cmp=True)
    def __len__(self):
        return len(self.tokens)
    def __getitem__(self, item):
        return self.tokens[item]
    def __iter__(self):
        return iter(self.tokens)
    def __contains__(self, item):
        return item in self.tokens
512
|
|
|
|
513
|
|
|
|
514
|
|
|
def kernel_def2_kernel(kernel_def):
    """Map a kernel definition like 'topic-kernel-0.60' to the short attribute name 'kernel60' (the part after the decimal point)."""
    decimal_part = kernel_def.rpartition('.')[2]
    return 'kernel{}'.format(decimal_part)
516
|
|
|
|
517
|
|
|
def top_tokens_def2_top(top_def):
    """Map a top-tokens definition like 'top-tokens-10' to the short attribute name 'top10' (the part after the last dash)."""
    cardinality = top_def.rpartition('-')[2]
    return 'top{}'.format(cardinality)
519
|
|
|
|
520
|
|
|
@attr.s(repr=True, str=True, cmp=True, slots=True)
class FinalStateEntities(object):
    """Token state at the end of training: per-threshold kernel tokens, per-cardinality top tokens and the background tokens."""
    # {'topic-kernel-<threshold>': TopicsTokens}
    kernel_hash = attr.ib(init=True, converter=lambda x: {kernel_def: TopicsTokens(data) for kernel_def, data in x.items()}, cmp=True, repr=False)
    # {'top-tokens-<cardinality>': TopicsTokens}
    top_hash = attr.ib(init=True, converter=lambda x: {top_tokens_def: TopicsTokens(data) for top_tokens_def, data in x.items()}, cmp=True, repr=False)
    _bg_tokens = attr.ib(init=True, converter=TokensList, cmp=True, repr=False)
    kernel_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.kernel_hash.keys()), takes_self=True))
    # Short attribute aliases, eg 'topic-kernel-0.60' -> 'kernel60'
    kernels = attr.ib(init=False, default=attr.Factory(lambda self: [kernel_def2_kernel(x) for x in self.kernel_defs], takes_self=True))
    top_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.top_hash.keys()), takes_self=True))
    # Short attribute aliases, eg 'top-tokens-10' -> 'top10'
    top = attr.ib(init=False, default=attr.Factory(lambda self: [top_tokens_def2_top(x) for x in self.top_defs], takes_self=True))
    background_tokens = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._bg_tokens.tokens), takes_self=True))

    parse = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __getattr__(self, item):
        """Resolve names like 'kernel60' or 'top10' to the corresponding TopicsTokens object."""
        d = self.parse(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.kernel_hash:
            return self.kernel_hash[key]
        elif key in self.top_hash:
            return self.top_hash[key]
        # Bug fix: raise AttributeError (was KeyError) so hasattr/getattr-with-default
        # work, matching ValueTracker.__getattr__ which raises AttributeError with an
        # almost identical message.
        raise AttributeError(
            "Requested item '{}', converted to '{}' after parsed as {}. It was not found either in kernel thresholds [{}] nor in top-toknes cardinalities [{}]".format(
                item, key, d, ', '.join(sorted(self.kernel_hash.keys())), ', '.join(sorted(self.top_hash.keys()))))

    def __str__(self):
        return 'bg-tokens: {}\n'.format(len(self._bg_tokens)) + '\n'.join(['final state: {}\n{}'.format(y, getattr(self, y)) for y in self.top + self.kernels])

    @classmethod
    def from_dict(cls, data):
        """Build from the 'final' section of a deserialized results dict."""
        return FinalStateEntities(
            {'topic-kernel-' + threshold: tokens_hash for threshold, tokens_hash in list(data['final']['topic-kernel'].items())},
            {'top-tokens-' + nb_tokens: tokens_hash for nb_tokens, tokens_hash in list(data['final']['top-tokens'].items())},
            data['final']['background-tokens']
        )

    @classmethod
    def from_experiment(cls, experiment):
        """Build from a live Experiment's final_tokens mapping."""
        return FinalStateEntities(experiment.final_tokens['topic-kernel'], experiment.final_tokens['top-tokens'], experiment.final_tokens['background-tokens'])
558
|
|
|
|
559
|
|
|
@attr.s(repr=True, cmp=True, slots=True)
class TopicsTokens(object):
    """Maps topic names to their TokensList and pretty-prints the per-topic token counts as a grid of up to _nb_columns columns."""
    _tokens = attr.ib(init=True, converter=lambda topic_name2tokens: {topic_name: TokensList(tokens_list) for topic_name, tokens_list in topic_name2tokens.items()}, cmp=True, repr=False)
    _nb_columns = attr.ib(init=False, default=10, cmp=False, repr=False)
    _nb_rows = attr.ib(init=False, default=attr.Factory(lambda self: int(ceil((float(len(self._tokens)) / self._nb_columns))), takes_self=True), repr=False, cmp=False)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._tokens.keys()), takes_self=True), repr=True, cmp=False)
    # Cached per-topic widths for grid alignment: 'string' is the topic-name length,
    # 'list' the character width of the token count.
    __lens = attr.ib(init=False, default=attr.Factory(lambda self: {k : {'string': len(k), 'list': len(str(len(v)))} for k, v in self._tokens.items()}, takes_self=True), cmp=False, repr=False)

    def __iter__(self):
        # Yield (topic_name, TokensList) pairs.
        return ((k,v) for k,v in self._tokens.items())

    def __getattr__(self, item):
        # Expose each topic's TokensList as an attribute named after the topic.
        if item in self._tokens:
            return self._tokens[item]
        raise AttributeError

    def __str__(self):
        # Render a grid where each cell reads '<topic_name>: <nb tokens>'.
        return '\n'.join([self._get_row_string(x) for x in self._gen_rows()])

    def _gen_rows(self):
        # Yield successive rows of topic names until all topics are consumed.
        i, j = 0, 0
        while i<len(self._tokens):
            _ = self._index(j)
            j += 1
            i += len(_)
            yield _

    def _index(self, row_nb):
        # Topic names belonging to row 'row_nb' (row-major layout, _nb_columns wide);
        # out-of-range cells are filtered out.
        return [_f for _f in [self.topics[x] if x<len(self._tokens) else None for x in range(self._nb_columns * row_nb, self._nb_columns * row_nb + self._nb_columns)] if _f]

    def _get_row_string(self, row_topic_names):
        # Pad each cell so names and counts align per column.
        return ' '.join(['{}{}: {}{}'.format(
            x[1], (self._max(x[0], 'string')-len(x[1]))*' ', (self._max(x[0], 'list')-self.__lens[x[1]]['list'])*' ', len(self._tokens[x[1]])) for x in enumerate(row_topic_names)])

    def _max(self, column_index, entity):
        # Widest value of 'entity' ('string' or 'list') over this column's cells.
        return max([self.__lens[self.topics[x]][entity] for x in self._range(column_index)])

    def _range(self, column_index):
        # Indices (into self.topics) of the cells in the given column.
        # NOTE(review): the exclusive end bound excludes the final row's cell, and
        # for a single-row grid the range is empty so _max crashes on max([]) --
        # looks like an off-by-one; confirm.
        return list(range(column_index, self._get_last_index_in_column(column_index), self._nb_columns))

    def _get_last_index_in_column(self, column_index):
        # Exclusive upper bound used by _range for this column.
        return self._nb_rows * self._nb_columns - self._nb_columns + column_index
601
|
|
|
|
602
|
|
|
|
603
|
|
|
class RoundTripEncoder(json.JSONEncoder):
    """JSON encoder that flattens every experimental-results type.

    Each ``isinstance`` branch converts one project type into plain
    dict/list/scalar structures so ``json.dumps(obj, cls=RoundTripEncoder)``
    can serialize an ExperimentalResults tree; anything unrecognized falls
    through to the stock JSONEncoder (which raises TypeError).
    """

    def default(self, obj):
        if isinstance(obj, TokensList):
            return obj.tokens
        if isinstance(obj, TopicsTokens):
            return {topic_name: list(getattr(obj, topic_name)) for topic_name in obj.topics}
        if isinstance(obj, FinalStateEntities):
            # keys are the trailing threshold / cardinality of each definition string
            return {'topic-kernel': {kernel_def.split('-')[-1]: getattr(obj, kernel_def2_kernel(kernel_def)) for kernel_def in obj.kernel_defs},
                    'top-tokens': {top_def.split('-')[-1]: getattr(obj, top_tokens_def2_top(top_def)) for top_def in obj.top_defs},
                    'background-tokens': obj.background_tokens}
        if isinstance(obj, SteadyTrackedItems):
            return {'dir': obj.dir,
                    'label': obj.model_label,
                    'dataset_iterations': obj.dataset_iterations,
                    'nb_topics': obj.nb_topics,
                    'document_passes': obj.document_passes,
                    'background_topics': obj.background_topics,
                    'domain_topics': obj.domain_topics,
                    'modalities': obj.modalities}
        if isinstance(obj, TrackedEntity):
            return obj.all
        if isinstance(obj, TrackedTrajectories):
            return {matrix_name: tau_elements for matrix_name, tau_elements in obj.trajectories}
        if isinstance(obj, TrackedEvolvingRegParams):
            return dict(obj)
        if isinstance(obj, TrackedTopTokens):
            return {'avg_coh': obj.average_coherence,
                    'topics': {topic_name: getattr(obj, topic_name).all for topic_name in obj.topics}}
        if isinstance(obj, TrackedKernel):
            return {'avg_coh': obj.average.coherence,
                    'avg_con': obj.average.contrast,
                    'avg_pur': obj.average.purity,
                    'size': obj.average.size,
                    'topics': {topic_name: {'coherence': getattr(obj, topic_name).coherence.all,
                                            'contrast': getattr(obj, topic_name).contrast.all,
                                            'purity': getattr(obj, topic_name).purity.all} for topic_name in obj.topics}}
        if isinstance(obj, ValueTracker):
            # NOTE(review): the `all(...)` clause can never be True (no name starts
            # with both prefixes), so the filter reduces to the `not in` test —
            # `any` was probably intended; preserved as-is to avoid changing output.
            _ = {name: tracked_entity for name, tracked_entity in list(obj.scores.items()) if name not in ['tau-trajectories', 'regularization-dynamic-parameters'] or
                 all(name.startswith(x) for x in ['top-tokens', 'topic-kernel'])}
            _['top-tokens'] = {k: obj.scores['top-tokens-' + str(k)] for k in obj.top_tokens_cardinalities}
            _['topic-kernel'] = {k: obj.scores['topic-kernel-' + str(k)] for k in obj.kernel_thresholds}
            _['tau-trajectories'] = {k: getattr(obj.tau_trajectories, k) for k in obj.tau_trajectory_matrices_names}
            # BUG FIX: a trailing comma here used to wrap the value in a 1-tuple,
            # serializing 'collection-passes' as a needlessly nested list.
            _['collection-passes'] = obj.collection_passes
            _['regularization-dynamic-parameters'] = obj.regularization_dynamic_parameters
            return _
        if isinstance(obj, ExperimentalResults):
            return {'scalars': obj.scalars,
                    'tracked': obj.tracked,
                    'final': obj.final,
                    'regularizers': obj.regularizers,
                    'reg_defs': obj.reg_defs,
                    'score_defs': obj.score_defs}
        return super(RoundTripEncoder, self).default(obj)
658
|
|
|
|
659
|
|
|
|
660
|
|
|
class RoundTripDecoder(json.JSONDecoder):
    """Counterpart of RoundTripEncoder: a JSONDecoder wired to a per-object hook.

    The hook currently returns every decoded object untouched; extend
    ``object_hook`` to rebuild rich result types from their plain-dict forms.
    """

    def __init__(self, *args, **kwargs):
        super(RoundTripDecoder, self).__init__(object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        # Identity hook — placeholder for future type reconstruction.
        return obj
666
|
|
|
|