import sys
import json
from math import ceil
import attr
from functools import reduce

from collections import OrderedDict
import re

import logging

logger = logging.getLogger(__name__)

@attr.s(cmp=True, repr=True, str=True)
class ExperimentalResults(object):
    scalars = attr.ib(init=True, cmp=True, repr=True)
    tracked = attr.ib(init=True, cmp=True, repr=True)
    final = attr.ib(init=True, cmp=True, repr=True)
    regularizers = attr.ib(init=True, converter=lambda x: sorted(x), cmp=True, repr=True)
    reg_defs = attr.ib(init=True, cmp=True, repr=True)
    score_defs = attr.ib(init=True, cmp=True, repr=True)

    def __str__(self):
        return 'Scalars:\n{}\nTracked:\n{}\nFinal:\n{}\nRegularizers: {}'.format(self.scalars, self.tracked, self.final, ', '.join(self.regularizers))

    def __eq__(self, other):
        return True  # [self.scalars, self.regularizers, self.reg_defs, self.score_defs] == [other.scalars, other.regularizers, other.reg_defs, other.score_defs]

    @property
    def tracked_kernels(self):
        return [getattr(self.tracked, 'kernel' + str(x)[2:]) for x in self.tracked.kernel_thresholds]

    @property
    def tracked_top_tokens(self):
        return [getattr(self.tracked, 'top{}'.format(x)) for x in self.tracked.top_tokens_cardinalities]

    @property
    def phi_sparsities(self):
        return [getattr(self.tracked, 'sparsity_phi_{}'.format(x)) for x in self.tracked.modalities_initials]

    def to_json(self, human_readable=True):
        if human_readable:
            indent = 2
        else:
            indent = None
        return json.dumps(self, cls=RoundTripEncoder, indent=indent)

    def save_as_json(self, file_path, human_readable=True, debug=False):
        if human_readable:
            indent = 2
        else:
            indent = None
        import os
        if not os.path.isdir(os.path.dirname(file_path)):
            raise FileNotFoundError("Directory in path '{}' has not been created.".format(os.path.dirname(file_path)))
        with open(file_path, 'w') as fp:
            json.dump(self, fp, cls=RoundTripEncoder, indent=indent)

    @classmethod
    def _legacy(cls, key, data, factory):
        if key not in data:
            logger.info("'{}' not in experimental results dict with keys [{}]. Perhaps the object is of the old format.".format(key, ', '.join(data.keys())))
            return factory(key, data)
        return data[key]

    @classmethod
    def create_from_json_file(cls, file_path):
        with open(file_path, 'r') as fp:
            res = json.load(fp, cls=RoundTripDecoder)
        return cls.from_dict(res)

    @classmethod
    def from_dict(cls, data):
        return ExperimentalResults(SteadyTrackedItems.from_dict(data),
                                   ValueTracker.from_dict(data),
                                   FinalStateEntities.from_dict(data),
                                   data['regularizers'],
                                   # data.get('reg_defs', {'t'+str(i): v[:v.index('|')] for i, v in enumerate(data['regularizers'])}),
                                   cls._legacy('reg_defs', data, lambda x, y: {'type-' + str(i): v[:v.index('|')] for i, v in enumerate(y['regularizers'])}),
                                   cls._legacy('score_defs', data,
                                               lambda x, y: {k: 'name-' + str(i) for i, k in
                                                             enumerate(_ for _ in sorted(y['tracked'].keys()) if _ not in ['tau-trajectories', 'regularization-dynamic-parameters', 'collection-passes'])}))

    @classmethod
    def create_from_experiment(cls, experiment):
        """
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        return ExperimentalResults(SteadyTrackedItems.from_experiment(experiment),
                                   ValueTracker.from_experiment(experiment),
                                   FinalStateEntities.from_experiment(experiment),
                                   [x.label for x in experiment.topic_model.regularizer_wrappers],
                                   {k: v for k, v in zip(experiment.topic_model.regularizer_unique_types, experiment.topic_model.regularizer_names)},
                                   experiment.topic_model.definition2evaluator_name)

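# Usage sketch (illustrative only): the round trip below assumes `some_experiment` is a
# patm.modeling.experiment.Experiment instance and that the target directory exists;
# neither is defined in this module.
#
#     results = ExperimentalResults.create_from_experiment(some_experiment)
#     results.save_as_json('/data/results.json', human_readable=True)
#     same_results = ExperimentalResults.create_from_json_file('/data/results.json')
#
# Serialization goes through RoundTripEncoder (defined at the end of this module), while
# deserialization relies on ExperimentalResults.from_dict to rebuild the nested objects.
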
############### PARSER #################
@attr.s(slots=False)
class StringToDictParser(object):
    """Parses a score-definition string by trying various regexes at runtime and returns the match capturing the
    most information (simply measured by the number of entities captured). Both '-' and '_' are accepted as
    separators and normalized to '-', and any '@' prefix on a modality argument is handled."""
    regs = attr.ib(init=False, default={'score-word': r'[a-zA-Z]+',
                                        'score-sep': r'(?:-|_)',
                                        'numerical-argument': r'(?: \d+\. )? \d+',
                                        'modality-argument': r'[a-zA-Z]',
                                        'modality-argum1ent': r'@?[a-zA-Z]c?'}, converter=lambda x: dict(x, **{'score-type': r'{score-word}(?:{score-sep}{score-word})*'.format(**x)}), cmp=True, repr=True)
    normalizer = attr.ib(init=False, default={
            'kernel': 'topic-kernel',
            'top': 'top-tokens',
            'btr': 'background-tokens-ratio'
    }, cmp=False, repr=True)

    def __call__(self, *args, **kwargs):
        self.string = args[0]
        self.design = kwargs.get('design', [
            r'({score-type}) {score-sep} ({numerical-argument})',
            r'({score-type}) {score-sep} @?({modality-argument})c?$',
            r'({score-type}) ({numerical-argument})',
            r'({score-type})'])
        if kwargs.get('debug', False):
            dd = []
            for d in self.design:
                dd.append(self.search_debug(d, args[0]))
            if kwargs.get('encode', False):
                return self.encode(max(dd, key=lambda x: len(x)))
            return max(dd, key=lambda x: len(x))
        if kwargs.get('encode', False):
            return self.encode(max([self.search_n_dict(r, args[0]) for r in self.design],
                                   key=lambda x: len(x)))
        return max([self.search_n_dict(r, args[0]) for r in self.design], key=lambda x: len(x))

    def search_n_dict(self, design_line, string):
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), list(getattr(re.compile(design_line.format(**self.regs), re.X).match(string), 'groups', lambda: len(self._entities(design_line)) * [''])())) if v])

    def search_debug(self, design_line, string):
        reg = re.compile(design_line.format(**self.regs), re.X)
        res = reg.search(string)
        if res:
            ls = res.groups()
        else:
            ls = len(self._entities(design_line)) * ['']
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), list(ls)) if v])

    def encode(self, ord_d):
        ord_d['score-type'] = self.normalizer.get(ord_d['score-type'].replace('_', '-'), ord_d['score-type'].replace('_', '-'))
        if 'modality-argument' in ord_d:
            ord_d['modality-argument'] = '@{}c'.format(ord_d['modality-argument'].lower())
        if 'numerical-argument' in ord_d:
            ord_d['numerical-argument'] = self._norm_numerical(ord_d['score-type'], ord_d['numerical-argument'])
        return ord_d

    def _norm_numerical(self, score_type, value):
        if score_type in ['topic-kernel', 'background-tokens-ratio']:
            a_float = float(value)
            if int(a_float) == a_float:
                value = '0.' + str(int(a_float))  # e.g. '60' is reinterpreted as the threshold '0.60'
            return '{:.2f}'.format(float(value))
        return value

    def _entities(self, design_line):
        return self._post_process(re.findall(r''.join([r'\({', r'(?:{score-type}|{numerical-argument}|{modality-argument})'.format(**self.regs), r'}\)']), design_line))

    def _post_process(self, entities):
        return [x[2:-2] for x in entities]

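# Hedged usage sketch: the helper below is illustrative (not part of the original API) and
# shows how score-definition strings are parsed and normalized by StringToDictParser.
def _example_parse_score_definitions():
    parser = StringToDictParser()
    # 'top_tokens_10'   -> OrderedDict([('score-type', 'top-tokens'), ('numerical-argument', '10')])
    # 'topic_kernel_60' -> OrderedDict([('score-type', 'topic-kernel'), ('numerical-argument', '0.60')])
    # 'sparsity_phi_d'  -> OrderedDict([('score-type', 'sparsity-phi'), ('modality-argument', '@dc')])
    return [parser(s, encode=True) for s in ('top_tokens_10', 'topic_kernel_60', 'sparsity_phi_d')]
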
################### VALUE TRACKER ###################
@attr.s(cmp=True, repr=True, str=True)
class ValueTracker(object):
    _tracked = attr.ib(init=True, cmp=True, repr=True)
    parser = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __attrs_post_init__(self):
        self.scores = {}
        self._rest = {}

        for tracked_definition, v in list(self._tracked.items()):  # assumes maximum depth is 2
            d = self.parser(tracked_definition, encode=True, debug=False)
            key = '-'.join(d.values())
            try:
                tracked_type = d['score-type']
            except KeyError:
                raise KeyError("String '{}' was not parsed successfully. d = {}".format(tracked_definition, d))
            if tracked_type == 'topic-kernel':
                self.scores[tracked_definition] = TrackedKernel(*v)
            elif tracked_type == 'top-tokens':
                self.scores[key] = TrackedTopTokens(*v)
            elif tracked_type in ['perplexity', 'sparsity-theta', 'background-tokens-ratio']:
                self.scores[key] = TrackedEntity(tracked_type, v)
            elif tracked_type in ['sparsity-phi']:
                self.scores[key] = TrackedEntity(tracked_type, v)
            elif tracked_type == 'tau-trajectories':
                self._rest[key] = TrackedTrajectories(v)
            elif tracked_type == 'regularization-dynamic-parameters':
                self._rest[key] = TrackedEvolvingRegParams(v)
            else:
                self._rest[key] = TrackedEntity(tracked_type, v)

    @classmethod
    def from_dict(cls, data):
        tracked = data['tracked']
        d = {'perplexity': tracked['perplexity'],
             'sparsity-theta': tracked['sparsity-theta'],
             'collection-passes': tracked['collection-passes'],
             'tau-trajectories': tracked['tau-trajectories'],
             }
        if 'regularization-dynamic-parameters' in tracked:
            d['regularization-dynamic-parameters'] = tracked['regularization-dynamic-parameters']
        else:
            logger.info("Did not find 'regularization-dynamic-parameters' in tracked values. Probably reading from a legacy formatted object.")
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            d,
            {key: v for key, v in list(tracked.items()) if key.startswith('background-tokens-ratio')},
            {'topic-kernel-' + key: [[value['avg_coh'], value['avg_con'], value['avg_pur'], value['size']],
                                     {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['topic-kernel'].items())},
            {'top-tokens-' + key: [value['avg_coh'], {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['top-tokens'].items())},
            {key: v for key, v in list(tracked.items()) if key.startswith('sparsity-phi-@')}
        ]))

    @classmethod
    def from_experiment(cls, experiment):
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            {'perplexity': experiment.trackables['perplexity'],
             'sparsity-theta': experiment.trackables['sparsity-theta'],
             'collection-passes': experiment.collection_passes,
             'tau-trajectories': {matrix_name: experiment.regularizers_dynamic_parameters.get('sparse-' + matrix_name, {}).get('tau', []) for matrix_name in ['theta', 'phi']},
             'regularization-dynamic-parameters': experiment.regularizers_dynamic_parameters},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('background-tokens-ratio')},
            {kernel_definition: [[value[0], value[1], value[2], value[3]], value[4]] for kernel_definition, value in
             list(experiment.trackables.items()) if kernel_definition.startswith('topic-kernel')},
            {top_tokens_definition: [value[0], value[1]] for top_tokens_definition, value in
             list(experiment.trackables.items()) if top_tokens_definition.startswith('top-tokens-')},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('sparsity-phi-@')},
        ]))

    def __dir__(self):
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def top_tokens_cardinalities(self):
        return sorted(int(_.split('-')[-1]) for _ in list(self.scores.keys()) if _.startswith('top-tokens'))

    @property
    def kernel_thresholds(self):
        """
        :return: list of strings, e.g. ['0.25', '0.60', '0.80']
        """
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('topic-kernel'))

    @property
    def modalities_initials(self):
        return sorted(_.split('-')[-1][1] for _ in list(self.scores.keys()) if _.startswith('sparsity-phi'))

    @property
    def tracked_entity_names(self):
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def background_tokens_thresholds(self):
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('background-tokens-ratio'))

    @property
    def tau_trajectory_matrices_names(self):
        try:
            return self._rest['tau-trajectories'].matrices_names
        except KeyError:
            raise KeyError("Key 'tau-trajectories' was not found in scores [{}]. ALL: [{}]".format(
                ', '.join(sorted(self.scores.keys())), ', '.join(self.tracked_entity_names)))

    def __getitem__(self, item):
        d = self.parser(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.scores:
            return self.scores[key]
        elif key in self._rest:
            return self._rest[key]
        raise KeyError(
            "Requested item '{}', converted to '{}', but it was found neither in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format(
                item, key, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys()))))

    def __getattr__(self, item):
        d = self.parser(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.scores:
            return self.scores[key]
        elif key in self._rest:
            return self._rest[key]
        raise AttributeError(
            "Requested item '{}', converted to '{}' after being parsed as {}. It was found neither in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format(
                item, key, d, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys()))))

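# Access sketch (illustrative): tracked scores are reachable both by indexing and by
# attribute, since both paths normalize the requested name through StringToDictParser.
# Assuming a populated `tracker` whose experiment defined a 0.60 kernel and a top-10
# tokens score (hypothetical names; the available ones depend on what was tracked):
#
#     tracker.perplexity.last                          # last value of the perplexity trajectory
#     tracker['top_tokens_10'].average_coherence.all
#     tracker.kernel60.average.purity.last             # normalized to the 'topic-kernel-0.60' entry
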
##############################################################
@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEvolvingRegParams(object):
    """Holds regularizers_parameters data, which is a dictionary: keys should be the unique regularizer definitions
    (as in train.cfg, e.g. 'label-regularization-phi-dom-cls'). Each key should map to a dictionary mapping strings to lists.
    These inner keys should correspond to one of the supported dynamic parameters (e.g. 'tau', 'gamma'), and each
    list should have length equal to the number of collection iterations, holding the evolution/trajectory of the parameter values.
    """
    _evolved = attr.ib(init=True, converter=lambda regularizers_params: {reg: {param: TrackedEntity(param, values_list) for param, values_list in reg_params.items()} for
                                                                         reg, reg_params in regularizers_params.items()}, cmp=True)
    regularizers_definitions = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._evolved.keys()), takes_self=True), cmp=False, repr=True)

    def __iter__(self):
        return ((k, v) for k, v in self._evolved.items())

    def __getattr__(self, item):
        return self._evolved[item]


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedKernel(object):
    """data = [avg_coherence_list, avg_contrast_list, avg_purity_list, sizes_list, topic_name2elements_hash]"""
    average = attr.ib(init=True, cmp=True, converter=lambda x: KernelSubGroup(x[0], x[1], x[2], x[3]))
    _topics_data = attr.ib(converter=lambda topic_name2elements_hash: {key: KernelSubGroup(val['coherence'], val['contrast'], val['purity'], key) for key, val in list(topic_name2elements_hash.items())}, init=True, cmp=True)
    topics = attr.ib(default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True))

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTopTokens(object):
    average_coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('average_coherence', x), cmp=True, repr=True)
    _topics_data = attr.ib(init=True, converter=lambda topic_name2coherence: {key: TrackedEntity(key, val) for key, val in list(topic_name2coherence.items())}, cmp=True, repr=True)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True), cmp=False, repr=True)

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))


@attr.s(cmp=True, repr=True, str=True, slots=True)
class KernelSubGroup(object):
    """Contains the per-threshold averages over topics for the metrics computed on lexical kernels [coherence, contrast, purity], as well as the average kernel size."""
    coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('coherence', x), cmp=True, repr=True)
    contrast = attr.ib(init=True, converter=lambda x: TrackedEntity('contrast', x), cmp=True, repr=True)
    purity = attr.ib(init=True, converter=lambda x: TrackedEntity('purity', x), cmp=True, repr=True)
    size = attr.ib(init=True, converter=lambda x: TrackedEntity('size', x), cmp=True, repr=True)
    name = attr.ib(init=True, default='', cmp=True, repr=True)


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTrajectories(object):
    _trajs = attr.ib(init=True, converter=lambda matrix_name2elements_hash: {k: TrackedEntity(k, v) for k, v in list(matrix_name2elements_hash.items())}, cmp=True, repr=True)
    trajectories = attr.ib(init=False, default=attr.Factory(lambda self: sorted(list(self._trajs.items()), key=lambda x: x[0]), takes_self=True), cmp=False, repr=False)
    matrices_names = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._trajs.keys()), takes_self=True), cmp=False, repr=True)

    @property
    def phi(self):
        if 'phi' in self._trajs:
            return self._trajs['phi']
        raise AttributeError

    @property
    def theta(self):
        if 'theta' in self._trajs:
            return self._trajs['theta']
        raise AttributeError

    def __str__(self):
        return str(self.matrices_names)

@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEntity(object):
    name = attr.ib(init=True, converter=str, cmp=True, repr=True)
    all = attr.ib(init=True, converter=list, cmp=True, repr=True)  # list of elements (values tracked)

    @property
    def last(self):
        return self.all[-1]

    def __len__(self):
        return len(self.all)

    def __getitem__(self, item):
        if item == 'all':
            return self.all
        try:
            return self.all[item]
        except (IndexError, TypeError):  # list indexing raises IndexError/TypeError, not KeyError
            raise KeyError("self: {} type of self._elements: {}".format(self, type(self.all).__name__))

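# Hedged usage sketch (not part of the original API): a TrackedEntity simply wraps the
# per-iteration values of one score.
def _example_tracked_entity():
    perplexity = TrackedEntity('perplexity', [5100.0, 2300.5, 1950.2])
    assert len(perplexity) == 3
    assert perplexity.last == 1950.2
    assert perplexity[0] == 5100.0
    return perplexity
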
######### CONSTANTS during training and between subsequent fit calls
def _check_topics(self, attribute, value):
    if len(value) != len(set(value)):
        raise RuntimeError("Detected duplicates in input topics: [{}]".format(', '.join(str(x) for x in value)))


def _non_overlapping(self, attribute, value):
    if any(x in value for x in self.background_topics):
        raise RuntimeError("Detected domain topics [{}] overlapping with background topics [{}]".format(', '.join(str(x) for x in value), ', '.join(str(x) for x in self.background_topics)))


def background_n_domain_topics_soundness(instance, attribute, value):
    if instance.nb_topics != len(instance.background_topics) + len(value):
        raise ValueError("nb_topics should be equal to len(background_topics) + len(domain_topics). Instead, {} != {} + {}".format(instance.nb_topics, len(instance.background_topics), len(value)))


@attr.s(repr=True, cmp=True, str=True, slots=True)
class SteadyTrackedItems(object):
    """Supports only one tokens list for a specific btr (background-tokens-ratio) threshold."""
    dir = attr.ib(init=True, converter=str, cmp=True, repr=True)  # T.O.D.O. rename to 'dataset_dir'
    model_label = attr.ib(init=True, converter=str, cmp=True, repr=True)
    dataset_iterations = attr.ib(init=True, converter=int, cmp=True, repr=True)
    nb_topics = attr.ib(init=True, converter=int, cmp=True, repr=True)
    document_passes = attr.ib(init=True, converter=int, cmp=True, repr=True)
    background_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=_check_topics)
    domain_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=[_check_topics, _non_overlapping, background_n_domain_topics_soundness])
    background_tokens_threshold = attr.ib(init=True, converter=float, cmp=True, repr=True)
    modalities = attr.ib(init=True, converter=dict, cmp=True, repr=True)

    parser = StringToDictParser()

    def __dir__(self):
        return ['dir', 'model_label', 'dataset_iterations', 'nb_topics', 'document_passes', 'background_topics', 'domain_topics', 'background_tokens_threshold', 'modalities']

    @classmethod
    def from_dict(cls, data):
        steady = data['scalars']
        background_tokens_threshold = 0  # no distinction between background and "domain" tokens
        for eval_def in data['tracked']:
            if eval_def.startswith('background-tokens-ratio'):
                background_tokens_threshold = max(background_tokens_threshold, float(eval_def.split('-')[-1]))
        return SteadyTrackedItems(steady['dir'], steady['label'], data['scalars']['dataset_iterations'], steady['nb_topics'],
                                  steady['document_passes'], steady['background_topics'], steady['domain_topics'], background_tokens_threshold, steady['modalities'])

    @classmethod
    def from_experiment(cls, experiment):
        """
        Uses the maximum threshold found amongst the defined 'background-tokens-ratio' scores.\n
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        ds = [d for d in experiment.topic_model.evaluator_definitions if d.startswith('background-tokens-ratio-')]
        m = 0
        for d in (_.split('-')[-1] for _ in ds):
            # normalize definitions that carry the threshold either as '0.25' or as '25'
            if str(d).startswith('0.'):
                m = max(m, float(d))
            else:
                m = max(m, float('0.' + str(d)))

        return SteadyTrackedItems(experiment.current_root_dir, experiment.topic_model.label, experiment.dataset_iterations,
                                  experiment.topic_model.nb_topics, experiment.topic_model.document_passes, experiment.topic_model.background_topics,
                                  experiment.topic_model.domain_topics, m, experiment.topic_model.modalities_dictionary)

@attr.s(repr=True, str=True, cmp=True, slots=True)
class TokensList(object):
    tokens = attr.ib(init=True, converter=list, repr=True, cmp=True)

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, item):
        return self.tokens[item]

    def __iter__(self):
        return iter(self.tokens)

    def __contains__(self, item):
        return item in self.tokens


def kernel_def2_kernel(kernel_def):
    return 'kernel' + kernel_def.split('.')[-1]


def top_tokens_def2_top(top_def):
    return 'top' + top_def.split('-')[-1]

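# For example (assuming the definition formats used throughout this module):
#     kernel_def2_kernel('topic-kernel-0.60')  ->  'kernel60'
#     top_tokens_def2_top('top-tokens-100')    ->  'top100'
# These short forms are the attribute names that FinalStateEntities (below) exposes.
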
@attr.s(repr=True, str=True, cmp=True, slots=True)
class FinalStateEntities(object):
    kernel_hash = attr.ib(init=True, converter=lambda x: {kernel_def: TopicsTokens(data) for kernel_def, data in x.items()}, cmp=True, repr=False)
    top_hash = attr.ib(init=True, converter=lambda x: {top_tokens_def: TopicsTokens(data) for top_tokens_def, data in x.items()}, cmp=True, repr=False)
    _bg_tokens = attr.ib(init=True, converter=TokensList, cmp=True, repr=False)
    kernel_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.kernel_hash.keys()), takes_self=True))
    kernels = attr.ib(init=False, default=attr.Factory(lambda self: [kernel_def2_kernel(x) for x in self.kernel_defs], takes_self=True))
    top_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.top_hash.keys()), takes_self=True))
    top = attr.ib(init=False, default=attr.Factory(lambda self: [top_tokens_def2_top(x) for x in self.top_defs], takes_self=True))
    background_tokens = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._bg_tokens.tokens), takes_self=True))

    parse = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __getattr__(self, item):
        d = self.parse(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.kernel_hash:
            return self.kernel_hash[key]
        elif key in self.top_hash:
            return self.top_hash[key]
        raise AttributeError(
            "Requested item '{}', converted to '{}' after being parsed as {}. It was found neither in the kernel thresholds [{}] nor in the top-tokens cardinalities [{}]".format(
                item, key, d, ', '.join(sorted(self.kernel_hash.keys())), ', '.join(sorted(self.top_hash.keys()))))

    def __str__(self):
        return 'bg-tokens: {}\n'.format(len(self._bg_tokens)) + '\n'.join(['final state: {}\n{}'.format(y, getattr(self, y)) for y in self.top + self.kernels])

    @classmethod
    def from_dict(cls, data):
        return FinalStateEntities(
            {'topic-kernel-' + threshold: tokens_hash for threshold, tokens_hash in list(data['final']['topic-kernel'].items())},
            {'top-tokens-' + nb_tokens: tokens_hash for nb_tokens, tokens_hash in list(data['final']['top-tokens'].items())},
            data['final']['background-tokens']
        )

    @classmethod
    def from_experiment(cls, experiment):
        return FinalStateEntities(experiment.final_tokens['topic-kernel'], experiment.final_tokens['top-tokens'], experiment.final_tokens['background-tokens'])

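# Access sketch (illustrative): assuming the experiment tracked a 0.60 kernel and a
# top-100 tokens score, the final token lists are reachable through the short attribute
# names produced by kernel_def2_kernel / top_tokens_def2_top above:
#
#     final = FinalStateEntities.from_experiment(some_experiment)  # `some_experiment` is hypothetical
#     final.kernel60            # TopicsTokens of the 'topic-kernel-0.60' definition
#     final.top100              # TopicsTokens of the 'top-tokens-100' definition
#     final.background_tokens   # sorted list of background tokens
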
@attr.s(repr=True, cmp=True, slots=True)
class TopicsTokens(object):
    _tokens = attr.ib(init=True, converter=lambda topic_name2tokens: {topic_name: TokensList(tokens_list) for topic_name, tokens_list in topic_name2tokens.items()}, cmp=True, repr=False)
    _nb_columns = attr.ib(init=False, default=10, cmp=False, repr=False)
    _nb_rows = attr.ib(init=False, default=attr.Factory(lambda self: int(ceil(float(len(self._tokens)) / self._nb_columns)), takes_self=True), repr=False, cmp=False)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._tokens.keys()), takes_self=True), repr=True, cmp=False)
    __lens = attr.ib(init=False, default=attr.Factory(lambda self: {k: {'string': len(k), 'list': len(str(len(v)))} for k, v in self._tokens.items()}, takes_self=True), cmp=False, repr=False)

    def __iter__(self):
        return ((k, v) for k, v in self._tokens.items())

    def __getattr__(self, item):
        if item in self._tokens:
            return self._tokens[item]
        raise AttributeError

    def __str__(self):
        return '\n'.join([self._get_row_string(x) for x in self._gen_rows()])

    def _gen_rows(self):
        i, j = 0, 0
        while i < len(self._tokens):
            _ = self._index(j)
            j += 1
            i += len(_)
            yield _

    def _index(self, row_nb):
        return [_f for _f in [self.topics[x] if x < len(self._tokens) else None for x in range(self._nb_columns * row_nb, self._nb_columns * row_nb + self._nb_columns)] if _f]

    def _get_row_string(self, row_topic_names):
        return '  '.join(['{}{}: {}{}'.format(
            x[1], (self._max(x[0], 'string') - len(x[1])) * ' ', (self._max(x[0], 'list') - self.__lens[x[1]]['list']) * ' ', len(self._tokens[x[1]])) for x in enumerate(row_topic_names)])

    def _max(self, column_index, entity):
        return max([self.__lens[self.topics[x]][entity] for x in self._range(column_index)])

    def _range(self, column_index):
        return list(range(column_index, self._get_last_index_in_column(column_index), self._nb_columns))

    def _get_last_index_in_column(self, column_index):
        return self._nb_rows * self._nb_columns - self._nb_columns + column_index

class RoundTripEncoder(json.JSONEncoder):

    def default(self, obj):
        if isinstance(obj, TokensList):
            return obj.tokens
        if isinstance(obj, TopicsTokens):
            return {topic_name: list(getattr(obj, topic_name)) for topic_name in obj.topics}
        if isinstance(obj, FinalStateEntities):
            return {'topic-kernel': {kernel_def.split('-')[-1]: getattr(obj, kernel_def2_kernel(kernel_def)) for kernel_def in obj.kernel_defs},
                    'top-tokens': {top_def.split('-')[-1]: getattr(obj, top_tokens_def2_top(top_def)) for top_def in obj.top_defs},
                    'background-tokens': obj.background_tokens}
        if isinstance(obj, SteadyTrackedItems):
            return {'dir': obj.dir,
                    'label': obj.model_label,
                    'dataset_iterations': obj.dataset_iterations,
                    'nb_topics': obj.nb_topics,
                    'document_passes': obj.document_passes,
                    'background_topics': obj.background_topics,
                    'domain_topics': obj.domain_topics,
                    'modalities': obj.modalities}
        if isinstance(obj, TrackedEntity):
            return obj.all
        if isinstance(obj, TrackedTrajectories):
            return {matrix_name: tau_elements for matrix_name, tau_elements in obj.trajectories}
        if isinstance(obj, TrackedEvolvingRegParams):
            return dict(obj)
        if isinstance(obj, TrackedTopTokens):
            return {'avg_coh': obj.average_coherence,
                    'topics': {topic_name: getattr(obj, topic_name).all for topic_name in obj.topics}}
        if isinstance(obj, TrackedKernel):
            return {'avg_coh': obj.average.coherence,
                    'avg_con': obj.average.contrast,
                    'avg_pur': obj.average.purity,
                    'size': obj.average.size,
                    'topics': {topic_name: {'coherence': getattr(obj, topic_name).coherence.all,
                                            'contrast': getattr(obj, topic_name).contrast.all,
                                            'purity': getattr(obj, topic_name).purity.all} for topic_name in obj.topics}}
        if isinstance(obj, ValueTracker):
            _ = {name: tracked_entity for name, tracked_entity in list(obj.scores.items()) if name not in ['tau-trajectories', 'regularization-dynamic-parameters'] or
                 all(name.startswith(x) for x in ['top-tokens', 'topic-kernel'])}
            _['top-tokens'] = {k: obj.scores['top-tokens-' + str(k)] for k in obj.top_tokens_cardinalities}
            _['topic-kernel'] = {k: obj.scores['topic-kernel-' + str(k)] for k in obj.kernel_thresholds}
            _['tau-trajectories'] = {k: getattr(obj.tau_trajectories, k) for k in obj.tau_trajectory_matrices_names}
            _['collection-passes'] = obj.collection_passes
            _['regularization-dynamic-parameters'] = obj.regularization_dynamic_parameters
            return _
        if isinstance(obj, ExperimentalResults):
            return {'scalars': obj.scalars,
                    'tracked': obj.tracked,
                    'final': obj.final,
                    'regularizers': obj.regularizers,
                    'reg_defs': obj.reg_defs,
                    'score_defs': obj.score_defs}
        return super(RoundTripEncoder, self).default(obj)


class RoundTripDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        return obj
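
# Note: RoundTripDecoder.object_hook is currently a pass-through, so JSON files are read
# back as plain dicts; ExperimentalResults.from_dict (via create_from_json_file) is what
# rebuilds the nested result objects.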