| Metric           | Value |
|------------------|-------|
| Total Complexity | 135   |
| Total Lines      | 666   |
| Duplicated Lines | 3%    |
| Changes          | 0     |
Duplicate code is one of the most pungent code smells. A commonly used rule of thumb is to restructure code once it is duplicated in three or more places. Common duplication problems, and their corresponding solutions, include extracting the repeated logic into a shared method or class.
Complex classes like results.experimental_results often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
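For instance, here is a minimal sketch of Extract Class; the Car/Engine names are invented for illustration and do not come from this codebase:

# Hypothetical 'before': the engine_* prefix hints at a hidden component.
class Car(object):
    def __init__(self):
        self.engine_rpm = 0
        self.engine_temperature = 90.0
        self.passengers = []

# Hypothetical 'after': the prefixed fields move into their own class.
class Engine(object):
    def __init__(self):
        self.rpm = 0
        self.temperature = 90.0

class RefactoredCar(object):
    def __init__(self):
        self.engine = Engine()  # cohesive component behind a single field
        self.passengers = []

The analyzed source follows.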
import sys
import os
import json
import logging
import re
from collections import OrderedDict
from functools import reduce
from math import ceil

import attr

logger = logging.getLogger(__name__)

@attr.s(cmp=True, repr=True, str=True)
class ExperimentalResults(object):
    scalars = attr.ib(init=True, cmp=True, repr=True)
    tracked = attr.ib(init=True, cmp=True, repr=True)
    final = attr.ib(init=True, cmp=True, repr=True)
    regularizers = attr.ib(init=True, converter=sorted, cmp=True, repr=True)
    reg_defs = attr.ib(init=True, cmp=True, repr=True)
    score_defs = attr.ib(init=True, cmp=True, repr=True)

    def __str__(self):
        return 'Scalars:\n{}\nTracked:\n{}\nFinal:\n{}\nRegularizers: {}'.format(
            self.scalars, self.tracked, self.final, ', '.join(self.regularizers))

    def __eq__(self, other):
        return [self.scalars, self.regularizers, self.reg_defs, self.score_defs] == \
               [other.scalars, other.regularizers, other.reg_defs, other.score_defs]

    @property
    def tracked_kernels(self):
        return [getattr(self.tracked, 'kernel' + str(x)[2:]) for x in self.tracked.kernel_thresholds]

    @property
    def tracked_top_tokens(self):
        return [getattr(self.tracked, 'top{}'.format(x)) for x in self.tracked.top_tokens_cardinalities]

    @property
    def phi_sparsities(self):
        return [getattr(self.tracked, 'sparsity_phi_{}'.format(x)) for x in self.tracked.modalities_initials]

    def to_json(self, human_readable=True):
        indent = 2 if human_readable else None
        return json.dumps(self, cls=RoundTripEncoder, indent=indent)

    def save_as_json(self, file_path, human_readable=True, debug=False):
        indent = 2 if human_readable else None
        if not os.path.isdir(os.path.dirname(file_path)):
            raise FileNotFoundError("Directory in path '{}' has not been created.".format(os.path.dirname(file_path)))
        with open(file_path, 'w') as fp:
            json.dump(self, fp, cls=RoundTripEncoder, indent=indent)

    @classmethod
    def _legacy(cls, key, data, factory):
        if key not in data:
            logger.info("'{}' not in experimental results dict with keys [{}]. Perhaps the object is of the old format".format(
                key, ', '.join(data.keys())))
            return factory(key, data)
        return data[key]

    @classmethod
    def create_from_json_file(cls, file_path):
        with open(file_path, 'r') as fp:
            res = json.load(fp, cls=RoundTripDecoder)
        return cls.from_dict(res)

    @classmethod
    def from_dict(cls, data):
        return ExperimentalResults(
            SteadyTrackedItems.from_dict(data),
            ValueTracker.from_dict(data),
            FinalStateEntities.from_dict(data),
            data['regularizers'],
            cls._legacy('reg_defs', data,
                        lambda x, y: {'type-' + str(i): v[:v.index('|')] for i, v in enumerate(y['regularizers'])}),
            cls._legacy('score_defs', data,
                        lambda x, y: {k: 'name-' + str(i) for i, k in enumerate(
                            _ for _ in sorted(y['tracked'].keys())
                            if _ not in ['tau-trajectories', 'regularization-dynamic-parameters', 'collection-passes'])}))

    @classmethod
    def create_from_experiment(cls, experiment):
        """
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        return ExperimentalResults(
            SteadyTrackedItems.from_experiment(experiment),
            ValueTracker.from_experiment(experiment),
            FinalStateEntities.from_experiment(experiment),
            [x.label for x in experiment.topic_model.regularizer_wrappers],
            dict(zip(experiment.topic_model.regularizer_unique_types, experiment.topic_model.regularizer_names)),
            experiment.topic_model.definition2evaluator_name)

############### PARSER #################

@attr.s(slots=False)
class StringToDictParser(object):
    """Parses a string by trying various regexes at runtime and returns the match that captures the most information
    (measured simply by the number of entities captured). When encoding, '_' separators are normalized to '-' and a
    modality argument is normalized to the '@xc' form."""
    regs = attr.ib(init=False, default={'score-word': r'[a-zA-Z]+',
                                        'score-sep': r'(?:-|_)',
                                        'numerical-argument': r'(?: \d+\. )? \d+',
                                        'modality-argument': r'[a-zA-Z]'},
                   converter=lambda x: dict(x, **{'score-type': r'{score-word}(?:{score-sep}{score-word})*'.format(**x)}),
                   cmp=True, repr=True)
    normalizer = attr.ib(init=False, default={
        'kernel': 'topic-kernel',
        'top': 'top-tokens',
        'btr': 'background-tokens-ratio'
    }, cmp=False, repr=True)

    def __call__(self, *args, **kwargs):
        self.string = args[0]
        self.design = kwargs.get('design', [
            r'({score-type}) {score-sep} ({numerical-argument})',
            r'({score-type}) {score-sep} @?({modality-argument})c?$',
            r'({score-type}) ({numerical-argument})',
            r'({score-type})'])
        if kwargs.get('debug', False):
            dd = [self.search_debug(d, args[0]) for d in self.design]
            if kwargs.get('encode', False):
                return self.encode(max(dd, key=len))
            return max(dd, key=len)
        if kwargs.get('encode', False):
            return self.encode(max([self.search_n_dict(r, args[0]) for r in self.design], key=len))
        return max([self.search_n_dict(r, args[0]) for r in self.design], key=len)

    def search_n_dict(self, design_line, string):
        return OrderedDict([(k, v) for k, v in zip(
            self._entities(design_line),
            list(getattr(re.compile(design_line.format(**self.regs), re.X).match(string),
                         'groups', lambda: len(self._entities(design_line)) * [''])())) if v])

    def search_debug(self, design_line, string):
        reg = re.compile(design_line.format(**self.regs), re.X)
        res = reg.search(string)
        if res:
            ls = res.groups()
        else:
            ls = len(self._entities(design_line)) * ['']
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), list(ls)) if v])

    def encode(self, ord_d):
        ord_d['score-type'] = self.normalizer.get(ord_d['score-type'].replace('_', '-'),
                                                  ord_d['score-type'].replace('_', '-'))
        if 'modality-argument' in ord_d:
            ord_d['modality-argument'] = '@{}c'.format(ord_d['modality-argument'].lower())
        if 'numerical-argument' in ord_d:
            ord_d['numerical-argument'] = self._norm_numerical(ord_d['score-type'], ord_d['numerical-argument'])
        return ord_d

    def _norm_numerical(self, score_type, value):
        if score_type in ['topic-kernel', 'background-tokens-ratio']:
            a_float = float(value)
            if int(a_float) == a_float:  # e.g. '60' is interpreted as the threshold 0.60
                value = '0.' + str(int(a_float))
            return '{:.2f}'.format(float(value))
        return value

    def _entities(self, design_line):
        return self._post_process(re.findall(
            r''.join([r'\({', r'(?:{score-type}|{numerical-argument}|{modality-argument})'.format(**self.regs), r'}\)']),
            design_line))

    def _post_process(self, entities):
        return [x[2:-2] for x in entities]
203 | ################### VALUE TRACKER ################### |
||
204 | |||
205 | @attr.s(cmp=True, repr=True, str=True) |
||
206 | class ValueTracker(object): |
||
207 | _tracked = attr.ib(init=True, cmp=True, repr=True) |
||
208 | parser = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False) |
||
209 | |||
210 | def __attrs_post_init__(self): |
||
211 | self.scores = {} |
||
212 | self._rest = {} |
||
213 | |||
214 | for tracked_definition, v in list(self._tracked.items()): # assumes maximum depth is 2 |
||
215 | |||
216 | d = self.parser(tracked_definition, encode=True, debug=False) |
||
217 | key = '-'.join(d.values()) |
||
218 | # print("DEF: {}, d: {}, key: {}".format(tracked_definition, d, key)) |
||
219 | try: |
||
220 | tracked_type = d['score-type'] |
||
221 | except KeyError: |
||
222 | raise KeyError("String '{}' was not parsed successfully. d = {}".format(tracked_definition, d)) |
||
223 | if tracked_type == 'topic-kernel': |
||
224 | self.scores[tracked_definition] = TrackedKernel(*v) |
||
225 | elif tracked_type == 'top-tokens': |
||
226 | self.scores[key] = TrackedTopTokens(*v) |
||
227 | elif tracked_type in ['perplexity', 'sparsity-theta', 'background-tokens-ratio']: |
||
228 | self.scores[key] = TrackedEntity(tracked_type, v) |
||
229 | elif tracked_type in ['sparsity-phi']: |
||
230 | self.scores[key] = TrackedEntity(tracked_type, v) |
||
231 | elif tracked_type == 'tau-trajectories': |
||
232 | self._rest[key] = TrackedTrajectories(v) |
||
233 | elif tracked_type == 'regularization-dynamic-parameters': |
||
234 | self._rest[key] = TrackedEvolvingRegParams(v) |
||
235 | else: |
||
236 | self._rest[key] = TrackedEntity(tracked_type, v) |
||
237 | |||
238 | @classmethod |
||
239 | def from_dict(cls, data): |
||
240 | tracked = data['tracked'] |
||
241 | d = {'perplexity': tracked['perplexity'], |
||
242 | 'sparsity-theta': tracked['sparsity-theta'], |
||
243 | 'collection-passes': tracked['collection-passes'], |
||
244 | 'tau-trajectories': tracked['tau-trajectories'], |
||
245 | } |
||
246 | if 'regularization-dynamic-parameters' in tracked: |
||
247 | d['regularization-dynamic-parameters'] = tracked['regularization-dynamic-parameters'] |
||
248 | else: |
||
249 | logger.info("Did not find 'regularization-dynamic-parameters' in tracked values. Probably, reading from legacy formatted object") |
||
250 | return ValueTracker(reduce(lambda x, y: dict(x, **y), [ |
||
251 | d, |
||
252 | {key: v for key, v in list(tracked.items()) if key.startswith('background-tokens-ratio')}, |
||
253 | {'topic-kernel-' + key: [[value['avg_coh'], value['avg_con'], value['avg_pur'], value['size']], |
||
254 | {t_name: t_data for t_name, t_data in list(value['topics'].items())}] |
||
255 | for key, value in list(tracked['topic-kernel'].items())}, |
||
256 | {'top-tokens-' + key: [value['avg_coh'], {t_name: t_data for t_name, t_data in list(value['topics'].items())}] |
||
257 | for key, value in list(tracked['top-tokens'].items())}, |
||
258 | {key: v for key, v in list(tracked.items()) if key.startswith('sparsity-phi-@')} |
||
259 | ])) |
||
260 | |||
261 | # except KeyError as e: |
||
262 | # raise TypeError("Error {}, all: [{}], scalars: [{}], tracked: [{}], final: [{}]".format(e, |
||
263 | # ', '.join(sorted(data.keys())), |
||
264 | # ', '.join(sorted(data['scalars'].keys())), |
||
265 | # ', '.join(sorted(data['tracked'].keys())), |
||
266 | # ', '.join(sorted(data['final'].keys())), |
||
267 | # )) |
||
268 | |||
269 | @classmethod |
||
270 | def from_experiment(cls, experiment): |
||
271 | return ValueTracker(reduce(lambda x, y: dict(x, **y), [ |
||
272 | {'perplexity': experiment.trackables['perplexity'], |
||
273 | 'sparsity-theta': experiment.trackables['sparsity-theta'], |
||
274 | 'collection-passes': experiment.collection_passes, |
||
275 | 'tau-trajectories': {matrix_name: experiment.regularizers_dynamic_parameters.get('sparse-' + matrix_name, {}).get('tau', []) for matrix_name in ['theta', 'phi']}, |
||
276 | 'regularization-dynamic-parameters': experiment.regularizers_dynamic_parameters}, |
||
277 | {key: v for key, v in list(experiment.trackables.items()) if key.startswith('background-tokens-ratio')}, |
||
278 | {kernel_definition: [[value[0], value[1], value[2], value[3]], value[4]] for kernel_definition, value in |
||
279 | list(experiment.trackables.items()) if kernel_definition.startswith('topic-kernel')}, |
||
280 | {top_tokens_definition: [value[0], value[1]] for top_tokens_definition, value in |
||
281 | list(experiment.trackables.items()) if top_tokens_definition.startswith('top-tokens-')}, |
||
282 | {key: v for key, v in list(experiment.trackables.items()) if key.startswith('sparsity-phi-@')}, |
||
283 | ])) |
||
284 | |||
285 | def __dir__(self): |
||
286 | return sorted(list(self.scores.keys()) + list(self._rest.keys())) |
||
287 | |||
288 | @property |
||
289 | def top_tokens_cardinalities(self): |
||
290 | return sorted(int(_.split('-')[-1]) for _ in list(self.scores.keys()) if _.startswith('top-tokens')) |
||
291 | |||
292 | @property |
||
293 | def kernel_thresholds(self): |
||
294 | """ |
||
295 | :return: list of strings eg ['0.60', '0.80', '0.25'] |
||
296 | """ |
||
297 | return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('topic-kernel')) |
||
298 | |||
299 | @property |
||
300 | def modalities_initials(self): |
||
301 | return sorted(_.split('-')[-1][1] for _ in list(self.scores.keys()) if _.startswith('sparsity-phi')) |
||
302 | |||
303 | @property |
||
304 | def tracked_entity_names(self): |
||
305 | return sorted(list(self.scores.keys()) + list(self._rest.keys())) |
||
306 | |||
307 | @property |
||
308 | def background_tokens_thresholds(self): |
||
309 | return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('background-tokens-ratio')) |
||
310 | |||
311 | @property |
||
312 | def tau_trajectory_matrices_names(self): |
||
313 | try: |
||
314 | return self._rest['tau-trajectories'].matrices_names |
||
315 | except KeyError: |
||
316 | raise KeyError("Key 'tau-trajectories' was not found in scores [{}]. ALL: [{}]".format( |
||
317 | ', '.join(sorted(self.scores.keys())), ', '.join(self.tracked_entity_names))) |
||
318 | |||
319 | View Code Duplication | def __getitem__(self, item): |
|
|
|||
320 | d = self.parser(item, encode=True, debug=False) |
||
321 | key = '-'.join(d.values()) |
||
322 | if key in self.scores: |
||
323 | return self.scores[key] |
||
324 | elif key in self._rest: |
||
325 | return self._rest[key] |
||
326 | raise KeyError( |
||
327 | "Requested item '{}', converted to '{}' but it was not found either in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format( |
||
328 | item, key, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys())))) |
||
329 | |||
330 | View Code Duplication | def __getattr__(self, item): |
|
331 | d = self.parser(item, encode=True, debug=False) |
||
332 | key = '-'.join(d.values()) |
||
333 | if key in self.scores: |
||
334 | return self.scores[key] |
||
335 | elif key in self._rest: |
||
336 | return self._rest[key] |
||
337 | raise AttributeError( |
||
338 | "Requested item '{}', converted to '{}' after parsed as {}. It was not found either in ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format( |
||
339 | item, key, d, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys())))) |
||
340 | |||
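
# Illustrative access patterns (not part of the original file; 'results_dict' is
# a hypothetical dict in the serialized format): attribute and item names are
# routed through StringToDictParser, so the following lookups are equivalent:
#
#   tracker = ValueTracker.from_dict(results_dict)
#   tracker.perplexity.last                    # last tracked perplexity value
#   tracker['top_10'].average_coherence.last   # coherence of the top-10-tokens score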

##############################################################

@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEvolvingRegParams(object):
    """Holds regularizers' dynamic-parameters data as a dictionary: keys should be the unique regularizer definitions
    (as in train.cfg; e.g. 'label-regularization-phi-dom-cls'). Each key should map to a dictionary mapping strings to
    lists. These inner keys should correspond to one of the supported dynamic parameters (e.g. 'tau', 'gamma') and each
    list should have length equal to the number of collection iterations, holding the evolution/trajectory of the
    parameter's values.
    """
    _evolved = attr.ib(init=True, converter=lambda regularizers_params: {
        reg: {param: TrackedEntity(param, values_list) for param, values_list in reg_params.items()}
        for reg, reg_params in regularizers_params.items()}, cmp=True)
    regularizers_definitions = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self._evolved.keys()), takes_self=True), cmp=False, repr=True)

    def __iter__(self):
        return ((k, v) for k, v in self._evolved.items())

    def __getattr__(self, item):
        try:
            return self._evolved[item]
        except KeyError:
            raise AttributeError(item)
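
# Illustrative input structure (not part of the original file; the regularizer
# and parameter names below are made up):
#
#   TrackedEvolvingRegParams({
#       'label-regularization-phi-dom-cls': {'tau': [1.0, 0.8, 0.6],
#                                            'gamma': [0.0, 0.0, 0.0]},
#   })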

@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedKernel(object):
    """data = [avg_coherence_list, avg_contrast_list, avg_purity_list, sizes_list, topic_name2elements_hash]"""
    average = attr.ib(init=True, cmp=True, converter=lambda x: KernelSubGroup(x[0], x[1], x[2], x[3]))
    _topics_data = attr.ib(converter=lambda topic_name2elements_hash: {
        key: KernelSubGroup(val['coherence'], val['contrast'], val['purity'], key)
        for key, val in list(topic_name2elements_hash.items())}, init=True, cmp=True)
    topics = attr.ib(default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True))

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))

@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTopTokens(object):
    average_coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('average_coherence', x), cmp=True, repr=True)
    _topics_data = attr.ib(init=True, converter=lambda topic_name2coherence: {
        key: TrackedEntity(key, val) for key, val in list(topic_name2coherence.items())}, cmp=True, repr=True)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True), cmp=False, repr=True)

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))


@attr.s(cmp=True, repr=True, str=True, slots=True)
class KernelSubGroup(object):
    """Contains averages over topics for metrics computed on lexical kernels defined with a specific threshold
    [coherence, contrast, purity], as well as the average kernel size."""
    coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('coherence', x), cmp=True, repr=True)
    contrast = attr.ib(init=True, converter=lambda x: TrackedEntity('contrast', x), cmp=True, repr=True)
    purity = attr.ib(init=True, converter=lambda x: TrackedEntity('purity', x), cmp=True, repr=True)
    size = attr.ib(init=True, converter=lambda x: TrackedEntity('size', x), cmp=True, repr=True)
    name = attr.ib(init=True, default='', cmp=True, repr=True)


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTrajectories(object):
    _trajs = attr.ib(init=True, converter=lambda matrix_name2elements_hash: {
        k: TrackedEntity(k, v) for k, v in list(matrix_name2elements_hash.items())}, cmp=True, repr=True)
    trajectories = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(list(self._trajs.items()), key=lambda x: x[0]), takes_self=True), cmp=False, repr=False)
    matrices_names = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._trajs.keys()), takes_self=True), cmp=False, repr=True)

    @property
    def phi(self):
        if 'phi' in self._trajs:
            return self._trajs['phi']
        raise AttributeError("No tau trajectory is tracked for the 'phi' matrix")

    @property
    def theta(self):
        if 'theta' in self._trajs:
            return self._trajs['theta']
        raise AttributeError("No tau trajectory is tracked for the 'theta' matrix")

    def __str__(self):
        return str(self.matrices_names)


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEntity(object):
    name = attr.ib(init=True, converter=str, cmp=True, repr=True)
    all = attr.ib(init=True, converter=list, cmp=True, repr=True)  # list of the tracked values

    @property
    def last(self):
        return self.all[-1]

    def __len__(self):
        return len(self.all)

    def __getitem__(self, item):
        if item == 'all':
            return self.all
        try:
            return self.all[item]
        except IndexError:
            raise KeyError("self: {}, type of self.all: {}".format(self, type(self.all).__name__))

######### CONSTANTS during training and between subsequent fit calls #########

# attrs validators, defined at module level (their first argument is the attrs instance under validation)

def _check_topics(self, attribute, value):
    if len(value) != len(set(value)):
        raise RuntimeError("Detected duplicates in input topics: [{}]".format(', '.join(str(x) for x in value)))

def _non_overlapping(self, attribute, value):
    if any(x in value for x in self.background_topics):
        raise RuntimeError("Detected overlapping domain topics [{}] with background topics [{}]".format(
            ', '.join(str(x) for x in value), ', '.join(str(x) for x in self.background_topics)))

def background_n_domain_topics_soundness(instance, attribute, value):
    if instance.nb_topics != len(instance.background_topics) + len(value):
        raise ValueError("nb_topics should be equal to len(background_topics) + len(domain_topics). Instead, {} != {} + {}".format(
            instance.nb_topics, len(instance.background_topics), len(value)))

@attr.s(repr=True, cmp=True, str=True, slots=True)
class SteadyTrackedItems(object):
    """Supports only one tokens list for a specific background-tokens-ratio threshold."""
    dir = attr.ib(init=True, converter=str, cmp=True, repr=True)  # TODO: rename to 'dataset_dir'
    model_label = attr.ib(init=True, converter=str, cmp=True, repr=True)
    dataset_iterations = attr.ib(init=True, converter=int, cmp=True, repr=True)
    nb_topics = attr.ib(init=True, converter=int, cmp=True, repr=True)
    document_passes = attr.ib(init=True, converter=int, cmp=True, repr=True)
    background_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=_check_topics)
    domain_topics = attr.ib(init=True, converter=list, cmp=True, repr=True,
                            validator=[_check_topics, _non_overlapping, background_n_domain_topics_soundness])
    background_tokens_threshold = attr.ib(init=True, converter=float, cmp=True, repr=True)
    modalities = attr.ib(init=True, converter=dict, cmp=True, repr=True)

    parser = StringToDictParser()

    def __dir__(self):
        return ['dir', 'model_label', 'dataset_iterations', 'nb_topics', 'document_passes',
                'background_topics', 'domain_topics', 'background_tokens_threshold', 'modalities']

    @classmethod
    def from_dict(cls, data):
        steady = data['scalars']
        background_tokens_threshold = 0  # no distinction between background and "domain" tokens
        for eval_def in data['tracked']:
            if eval_def.startswith('background-tokens-ratio'):
                background_tokens_threshold = max(background_tokens_threshold, float(eval_def.split('-')[-1]))
        return SteadyTrackedItems(steady['dir'], steady['label'], steady['dataset_iterations'], steady['nb_topics'],
                                  steady['document_passes'], steady['background_topics'], steady['domain_topics'],
                                  background_tokens_threshold, steady['modalities'])

    @classmethod
    def from_experiment(cls, experiment):
        """
        Uses the maximum threshold found amongst the defined 'background-tokens-ratio' scores.\n
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        ds = [d for d in experiment.topic_model.evaluator_definitions if d.startswith('background-tokens-ratio-')]
        m = 0
        for d in (_.split('-')[-1] for _ in ds):
            if str(d).startswith('0.'):
                m = max(m, float(d))
            else:
                m = max(m, float('0.' + str(d)))
        return SteadyTrackedItems(experiment.current_root_dir, experiment.topic_model.label, experiment.dataset_iterations,
                                  experiment.topic_model.nb_topics, experiment.topic_model.document_passes,
                                  experiment.topic_model.background_topics, experiment.topic_model.domain_topics, m,
                                  experiment.topic_model.modalities_dictionary)

@attr.s(repr=True, str=True, cmp=True, slots=True)
class TokensList(object):
    tokens = attr.ib(init=True, converter=list, repr=True, cmp=True)

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, item):
        return self.tokens[item]

    def __iter__(self):
        return iter(self.tokens)

    def __contains__(self, item):
        return item in self.tokens


def kernel_def2_kernel(kernel_def):
    return 'kernel' + kernel_def.split('.')[-1]


def top_tokens_def2_top(top_def):
    return 'top' + top_def.split('-')[-1]

@attr.s(repr=True, str=True, cmp=True, slots=True)
class FinalStateEntities(object):
    kernel_hash = attr.ib(init=True, converter=lambda x: {kernel_def: TopicsTokens(data) for kernel_def, data in x.items()}, cmp=True, repr=False)
    top_hash = attr.ib(init=True, converter=lambda x: {top_tokens_def: TopicsTokens(data) for top_tokens_def, data in x.items()}, cmp=True, repr=False)
    _bg_tokens = attr.ib(init=True, converter=TokensList, cmp=True, repr=False)
    kernel_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.kernel_hash.keys()), takes_self=True))
    kernels = attr.ib(init=False, default=attr.Factory(lambda self: [kernel_def2_kernel(x) for x in self.kernel_defs], takes_self=True))
    top_defs = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self.top_hash.keys()), takes_self=True))
    top = attr.ib(init=False, default=attr.Factory(lambda self: [top_tokens_def2_top(x) for x in self.top_defs], takes_self=True))
    background_tokens = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._bg_tokens.tokens), takes_self=True))

    parse = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __getattr__(self, item):
        d = self.parse(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.kernel_hash:
            return self.kernel_hash[key]
        elif key in self.top_hash:
            return self.top_hash[key]
        raise AttributeError(
            "Requested item '{}', converted to '{}' after being parsed as {}. It was found neither in the kernel thresholds [{}] nor in the top-tokens cardinalities [{}]".format(
                item, key, d, ', '.join(sorted(self.kernel_hash.keys())), ', '.join(sorted(self.top_hash.keys()))))

    def __str__(self):
        return 'bg-tokens: {}\n'.format(len(self._bg_tokens)) + '\n'.join(
            ['final state: {}\n{}'.format(y, getattr(self, y)) for y in self.top + self.kernels])

    @classmethod
    def from_dict(cls, data):
        return FinalStateEntities(
            {'topic-kernel-' + threshold: tokens_hash for threshold, tokens_hash in list(data['final']['topic-kernel'].items())},
            {'top-tokens-' + nb_tokens: tokens_hash for nb_tokens, tokens_hash in list(data['final']['top-tokens'].items())},
            data['final']['background-tokens']
        )

    @classmethod
    def from_experiment(cls, experiment):
        return FinalStateEntities(experiment.final_tokens['topic-kernel'],
                                  experiment.final_tokens['top-tokens'],
                                  experiment.final_tokens['background-tokens'])

@attr.s(repr=True, cmp=True, slots=True)
class TopicsTokens(object):
    _tokens = attr.ib(init=True, converter=lambda topic_name2tokens: {
        topic_name: TokensList(tokens_list) for topic_name, tokens_list in topic_name2tokens.items()}, cmp=True, repr=False)
    _nb_columns = attr.ib(init=False, default=10, cmp=False, repr=False)
    _nb_rows = attr.ib(init=False, default=attr.Factory(
        lambda self: int(ceil(float(len(self._tokens)) / self._nb_columns)), takes_self=True), repr=False, cmp=False)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._tokens.keys()), takes_self=True), repr=True, cmp=False)
    __lens = attr.ib(init=False, default=attr.Factory(
        lambda self: {k: {'string': len(k), 'list': len(str(len(v)))} for k, v in self._tokens.items()}, takes_self=True), cmp=False, repr=False)

    def __iter__(self):
        return ((k, v) for k, v in self._tokens.items())

    def __getattr__(self, item):
        if item in self._tokens:
            return self._tokens[item]
        raise AttributeError(item)

    def __str__(self):
        return '\n'.join([self._get_row_string(x) for x in self._gen_rows()])

    def _gen_rows(self):
        i, j = 0, 0
        while i < len(self._tokens):
            row = self._index(j)
            j += 1
            i += len(row)
            yield row

    def _index(self, row_nb):
        return [_f for _f in [self.topics[x] if x < len(self._tokens) else None
                              for x in range(self._nb_columns * row_nb, self._nb_columns * row_nb + self._nb_columns)] if _f]

    def _get_row_string(self, row_topic_names):
        return ' '.join(['{}{}: {}{}'.format(
            x[1],
            (self._max(x[0], 'string') - len(x[1])) * ' ',
            (self._max(x[0], 'list') - self.__lens[x[1]]['list']) * ' ',
            len(self._tokens[x[1]])) for x in enumerate(row_topic_names)])

    def _max(self, column_index, entity):
        return max([self.__lens[self.topics[x]][entity] for x in self._range(column_index)])

    def _range(self, column_index):
        return list(range(column_index, self._get_last_index_in_column(column_index), self._nb_columns))

    def _get_last_index_in_column(self, column_index):
        return self._nb_rows * self._nb_columns - self._nb_columns + column_index

class RoundTripEncoder(json.JSONEncoder):

    def default(self, obj):
        if isinstance(obj, TokensList):
            return obj.tokens
        if isinstance(obj, TopicsTokens):
            return {topic_name: list(getattr(obj, topic_name)) for topic_name in obj.topics}
        if isinstance(obj, FinalStateEntities):
            return {'topic-kernel': {kernel_def.split('-')[-1]: getattr(obj, kernel_def2_kernel(kernel_def)) for kernel_def in obj.kernel_defs},
                    'top-tokens': {top_def.split('-')[-1]: getattr(obj, top_tokens_def2_top(top_def)) for top_def in obj.top_defs},
                    'background-tokens': obj.background_tokens}
        if isinstance(obj, SteadyTrackedItems):
            return {'dir': obj.dir,
                    'label': obj.model_label,
                    'dataset_iterations': obj.dataset_iterations,
                    'nb_topics': obj.nb_topics,
                    'document_passes': obj.document_passes,
                    'background_topics': obj.background_topics,
                    'domain_topics': obj.domain_topics,
                    'modalities': obj.modalities}
        if isinstance(obj, TrackedEntity):
            return obj.all
        if isinstance(obj, TrackedTrajectories):
            return {matrix_name: tau_elements for matrix_name, tau_elements in obj.trajectories}
        if isinstance(obj, TrackedEvolvingRegParams):
            return dict(obj)
        if isinstance(obj, TrackedTopTokens):
            return {'avg_coh': obj.average_coherence,
                    'topics': {topic_name: getattr(obj, topic_name).all for topic_name in obj.topics}}
        if isinstance(obj, TrackedKernel):
            return {'avg_coh': obj.average.coherence,
                    'avg_con': obj.average.contrast,
                    'avg_pur': obj.average.purity,
                    'size': obj.average.size,
                    'topics': {topic_name: {'coherence': getattr(obj, topic_name).coherence.all,
                                            'contrast': getattr(obj, topic_name).contrast.all,
                                            'purity': getattr(obj, topic_name).purity.all} for topic_name in obj.topics}}
        if isinstance(obj, ValueTracker):
            # Keep the scalar-like scores; the flat top-tokens and topic-kernel entries
            # are re-added below, grouped under their cardinality/threshold.
            _ = {name: tracked_entity for name, tracked_entity in list(obj.scores.items())
                 if name not in ['tau-trajectories', 'regularization-dynamic-parameters']
                 and not any(name.startswith(x) for x in ('top-tokens', 'topic-kernel'))}
            _['top-tokens'] = {k: obj.scores['top-tokens-' + str(k)] for k in obj.top_tokens_cardinalities}
            _['topic-kernel'] = {k: obj.scores['topic-kernel-' + str(k)] for k in obj.kernel_thresholds}
            _['tau-trajectories'] = {k: getattr(obj.tau_trajectories, k) for k in obj.tau_trajectory_matrices_names}
            _['collection-passes'] = obj.collection_passes
            _['regularization-dynamic-parameters'] = obj.regularization_dynamic_parameters
            return _
        if isinstance(obj, ExperimentalResults):
            return {'scalars': obj.scalars,
                    'tracked': obj.tracked,
                    'final': obj.final,
                    'regularizers': obj.regularizers,
                    'reg_defs': obj.reg_defs,
                    'score_defs': obj.score_defs}
        return super(RoundTripEncoder, self).default(obj)

class RoundTripDecoder(json.JSONDecoder):

    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        # Identity hook: decoding yields plain dicts/lists; ExperimentalResults.from_dict
        # rebuilds the typed objects.
        return obj
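
# Illustrative round trip (not part of the original file; 'results.json' is a
# hypothetical path whose parent directory must already exist):
#
#   exp_res = ExperimentalResults.create_from_experiment(experiment)
#   exp_res.save_as_json('results.json', human_readable=True)
#   same_res = ExperimentalResults.create_from_json_file('results.json')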