| Metric | Value |
| --- | --- |
| Total Complexity | 135 |
| Total Lines | 666 |
| Duplicated Lines | 3 % |
| Changes | 0 |
Duplicate code is one of the most pungent code smells. A rule of thumb that is often used is to restructure code once it is duplicated in three or more places.

Complex classes like results.experimental_results often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.
Once you have determined which fields belong together, you can apply the Extract Class refactoring, as sketched below. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
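For instance, fields sharing a common prefix hint at a hidden component that can be pulled out into its own class. A minimal illustrative sketch (the names are invented for the example, not taken from the module below):

```python
# Before: one class owns two loosely related responsibilities.
class Report:
    def __init__(self, header_title, header_date, body):
        self.header_title = header_title  # the 'header_' prefix hints at a hidden component
        self.header_date = header_date
        self.body = body


# After applying Extract Class: the 'header_' fields move into a cohesive class of their own.
class Header:
    def __init__(self, title, date):
        self.title = title
        self.date = date


class Report:
    def __init__(self, header, body):
        self.header = header  # composed instead of inlined
        self.body = body
```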
The flagged module, `results.experimental_results`:

```python
import os
import json
import re
import logging
from math import ceil
from functools import reduce
from collections import OrderedDict

import attr

logger = logging.getLogger(__name__)


@attr.s(cmp=True, repr=True, str=True)
class ExperimentalResults(object):
    scalars = attr.ib(init=True, cmp=True, repr=True)
    tracked = attr.ib(init=True, cmp=True, repr=True)
    final = attr.ib(init=True, cmp=True, repr=True)
    regularizers = attr.ib(init=True, converter=lambda x: sorted(x), cmp=True, repr=True)
    reg_defs = attr.ib(init=True, cmp=True, repr=True)
    score_defs = attr.ib(init=True, cmp=True, repr=True)

    def __str__(self):
        return 'Scalars:\n{}\nTracked:\n{}\nFinal:\n{}\nRegularizers: {}'.format(
            self.scalars, self.tracked, self.final, ', '.join(self.regularizers))

    def __eq__(self, other):
        # Compare the identifying components; an unconditional 'return True' here
        # makes every instance compare equal.
        return [self.scalars, self.regularizers, self.reg_defs, self.score_defs] == \
               [other.scalars, other.regularizers, other.reg_defs, other.score_defs]

    @property
    def tracked_kernels(self):
        return [getattr(self.tracked, 'kernel' + str(x)[2:]) for x in self.tracked.kernel_thresholds]

    @property
    def tracked_top_tokens(self):
        return [getattr(self.tracked, 'top{}'.format(x)) for x in self.tracked.top_tokens_cardinalities]

    @property
    def phi_sparsities(self):
        # The format placeholder was missing: 'sparsity_phi_'.format(x) silently drops x.
        return [getattr(self.tracked, 'sparsity_phi_{}'.format(x)) for x in self.tracked.modalities_initials]

    def to_json(self, human_readable=True):
        indent = 2 if human_readable else None
        return json.dumps(self, cls=RoundTripEncoder, indent=indent)

    def save_as_json(self, file_path, human_readable=True, debug=False):
        indent = 2 if human_readable else None
        if not os.path.isdir(os.path.dirname(file_path)):
            raise FileNotFoundError("Directory in path '{}' has not been created.".format(os.path.dirname(file_path)))
        with open(file_path, 'w') as fp:
            json.dump(self, fp, cls=RoundTripEncoder, indent=indent)

    @classmethod
    def _legacy(cls, key, data, factory):
        if key not in data:
            logger.info("'{}' not in experimental results dict with keys [{}]. "
                        "Perhaps the object is of the old format.".format(key, ', '.join(data.keys())))
            return factory(key, data)
        return data[key]

    @classmethod
    def create_from_json_file(cls, file_path):
        with open(file_path, 'r') as fp:
            res = json.load(fp, cls=RoundTripDecoder)
        return cls.from_dict(res)

    @classmethod
    def from_dict(cls, data):
        return ExperimentalResults(
            SteadyTrackedItems.from_dict(data),
            ValueTracker.from_dict(data),
            FinalStateEntities.from_dict(data),
            data['regularizers'],
            cls._legacy('reg_defs', data,
                        lambda x, y: {'type-' + str(i): v[:v.index('|')] for i, v in enumerate(y['regularizers'])}),
            cls._legacy('score_defs', data,
                        lambda x, y: {k: 'name-' + str(i) for i, k in
                                      enumerate(_ for _ in sorted(y['tracked'].keys())
                                                if _ not in ['tau-trajectories',
                                                             'regularization-dynamic-parameters',
                                                             'collection-passes'])}))

    @classmethod
    def create_from_experiment(cls, experiment):
        """
        :param patm.modeling.experiment.Experiment experiment:
        :return: the results derived from the experiment's tracked values
        :rtype: ExperimentalResults
        """
        return ExperimentalResults(
            SteadyTrackedItems.from_experiment(experiment),
            ValueTracker.from_experiment(experiment),
            FinalStateEntities.from_experiment(experiment),
            [x.label for x in experiment.topic_model.regularizer_wrappers],
            {k: v for k, v in zip(experiment.topic_model.regularizer_unique_types,
                                  experiment.topic_model.regularizer_names)},
            experiment.topic_model.definition2evaluator_name)
```
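For orientation, this is roughly how the JSON round-trip above might be exercised. The `results` instance and the path are illustrative; note that `save_as_json` requires the parent directory to already exist:

```python
# Hypothetical usage: 'results' is an ExperimentalResults instance built elsewhere,
# e.g. via ExperimentalResults.create_from_experiment(experiment).
results.save_as_json('/tmp/results/exp1.json', human_readable=True)  # pretty-printed with indent=2
loaded = ExperimentalResults.create_from_json_file('/tmp/results/exp1.json')
assert loaded == results  # __eq__ compares scalars, regularizers, reg_defs and score_defs
```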
```python
############### PARSER #################

@attr.s(slots=False)
class StringToDictParser(object):
    """Parses a string (any '_' is normalized to '-'; any '@' is stripped and re-added canonically) by trying
    various regexes at runtime and returning the match capturing the most information (measured simply by the
    number of entities captured)."""
    regs = attr.ib(init=False, default={'score-word': r'[a-zA-Z]+',
                                        'score-sep': r'(?:-|_)',
                                        'numerical-argument': r'(?: \d+\. )? \d+',
                                        'modality-argument': r'[a-zA-Z]'},
                   converter=lambda x: dict(x, **{'score-type': r'{score-word}(?:{score-sep}{score-word})*'.format(**x)}),
                   cmp=True, repr=True)
    normalizer = attr.ib(init=False, default={
        'kernel': 'topic-kernel',
        'top': 'top-tokens',
        'btr': 'background-tokens-ratio'
    }, cmp=False, repr=True)

    def __call__(self, *args, **kwargs):
        self.string = args[0]
        self.design = kwargs.get('design', [
            r'({score-type}) {score-sep} ({numerical-argument})',
            r'({score-type}) {score-sep} @?({modality-argument})c?$',
            r'({score-type}) ({numerical-argument})',
            r'({score-type})'])
        # Try every design regex and keep the result capturing the most entities.
        search = self.search_debug if kwargs.get('debug', False) else self.search_n_dict
        best = max([search(d, args[0]) for d in self.design], key=lambda x: len(x))
        if kwargs.get('encode', False):
            return self.encode(best)
        return best

    def search_n_dict(self, design_line, string):
        """Matches 'string' against a design regex; entities that did not capture anything are dropped."""
        match = re.compile(design_line.format(**self.regs), re.X).match(string)
        if match:
            values = match.groups()
        else:
            values = len(self._entities(design_line)) * ['']
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), values) if v])

    def search_debug(self, design_line, string):
        """Same as search_n_dict, but scans with re.search instead of anchoring with re.match."""
        res = re.compile(design_line.format(**self.regs), re.X).search(string)
        if res:
            values = res.groups()
        else:
            values = len(self._entities(design_line)) * ['']
        return OrderedDict([(k, v) for k, v in zip(self._entities(design_line), values) if v])

    def encode(self, ord_d):
        ord_d['score-type'] = self.normalizer.get(ord_d['score-type'].replace('_', '-'),
                                                  ord_d['score-type'].replace('_', '-'))
        if 'modality-argument' in ord_d:
            ord_d['modality-argument'] = '@{}c'.format(ord_d['modality-argument'].lower())
        if 'numerical-argument' in ord_d:
            ord_d['numerical-argument'] = self._norm_numerical(ord_d['score-type'], ord_d['numerical-argument'])
        return ord_d

    def _norm_numerical(self, score_type, value):
        # Kernel and background-tokens-ratio scores take a threshold in [0, 1]; normalize
        # integer shorthands (eg '60') to a two-decimal string (eg '0.60').
        if score_type in ['topic-kernel', 'background-tokens-ratio']:
            a_float = float(value)
            if int(a_float) == a_float:
                value = '0.' + str(int(a_float))
            return '{:.2f}'.format(float(value))
        return value

    def _entities(self, design_line):
        pattern = r''.join([r'\({',
                            r'(?:{score-type}|{numerical-argument}|{modality-argument})'.format(**self.regs),
                            r'}\)'])
        return self._post_process(re.findall(pattern, design_line))

    def _post_process(self, entities):
        return [x[2:-2] for x in entities]  # strip the surrounding '({' and '})'
```
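Tracing the design regexes above, the parser should behave roughly as follows (a sketch; the expected outputs are shown as comments, not tool output):

```python
parser = StringToDictParser()

# 'top_tokens-100': the '({score-type}) {score-sep} ({numerical-argument})' design wins,
# and encode() normalizes underscores to hyphens:
parser('top_tokens-100', encode=True)
# OrderedDict([('score-type', 'top-tokens'), ('numerical-argument', '100')])

# For 'topic-kernel' / 'background-tokens-ratio' scores the numerical argument is
# normalized to a two-decimal threshold string:
parser('kernel-60', encode=True)
# OrderedDict([('score-type', 'topic-kernel'), ('numerical-argument', '0.60')])
```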
```python
################### VALUE TRACKER ###################

@attr.s(cmp=True, repr=True, str=True)
class ValueTracker(object):
    _tracked = attr.ib(init=True, cmp=True, repr=True)
    parser = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __attrs_post_init__(self):
        self.scores = {}
        self._rest = {}

        for tracked_definition, v in list(self._tracked.items()):  # assumes maximum depth is 2
            d = self.parser(tracked_definition, encode=True, debug=False)
            key = '-'.join(d.values())
            try:
                tracked_type = d['score-type']
            except KeyError:
                raise KeyError("String '{}' was not parsed successfully. d = {}".format(tracked_definition, d))
            if tracked_type == 'topic-kernel':
                self.scores[tracked_definition] = TrackedKernel(*v)
            elif tracked_type == 'top-tokens':
                self.scores[key] = TrackedTopTokens(*v)
            elif tracked_type in ['perplexity', 'sparsity-theta', 'background-tokens-ratio', 'sparsity-phi']:
                self.scores[key] = TrackedEntity(tracked_type, v)
            elif tracked_type == 'tau-trajectories':
                self._rest[key] = TrackedTrajectories(v)
            elif tracked_type == 'regularization-dynamic-parameters':
                self._rest[key] = TrackedEvolvingRegParams(v)
            else:
                self._rest[key] = TrackedEntity(tracked_type, v)

    @classmethod
    def from_dict(cls, data):
        tracked = data['tracked']
        d = {'perplexity': tracked['perplexity'],
             'sparsity-theta': tracked['sparsity-theta'],
             'collection-passes': tracked['collection-passes'],
             'tau-trajectories': tracked['tau-trajectories'],
             }
        if 'regularization-dynamic-parameters' in tracked:
            d['regularization-dynamic-parameters'] = tracked['regularization-dynamic-parameters']
        else:
            logger.info("Did not find 'regularization-dynamic-parameters' in tracked values. "
                        "Probably reading from a legacy-formatted object.")
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            d,
            {key: v for key, v in list(tracked.items()) if key.startswith('background-tokens-ratio')},
            {'topic-kernel-' + key: [[value['avg_coh'], value['avg_con'], value['avg_pur'], value['size']],
                                     {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['topic-kernel'].items())},
            {'top-tokens-' + key: [value['avg_coh'],
                                   {t_name: t_data for t_name, t_data in list(value['topics'].items())}]
             for key, value in list(tracked['top-tokens'].items())},
            {key: v for key, v in list(tracked.items()) if key.startswith('sparsity-phi-@')}
        ]))

    @classmethod
    def from_experiment(cls, experiment):
        return ValueTracker(reduce(lambda x, y: dict(x, **y), [
            {'perplexity': experiment.trackables['perplexity'],
             'sparsity-theta': experiment.trackables['sparsity-theta'],
             'collection-passes': experiment.collection_passes,
             'tau-trajectories': {matrix_name: experiment.regularizers_dynamic_parameters.get(
                 'sparse-' + matrix_name, {}).get('tau', []) for matrix_name in ['theta', 'phi']},
             'regularization-dynamic-parameters': experiment.regularizers_dynamic_parameters},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('background-tokens-ratio')},
            {kernel_definition: [[value[0], value[1], value[2], value[3]], value[4]]
             for kernel_definition, value in list(experiment.trackables.items())
             if kernel_definition.startswith('topic-kernel')},
            {top_tokens_definition: [value[0], value[1]]
             for top_tokens_definition, value in list(experiment.trackables.items())
             if top_tokens_definition.startswith('top-tokens-')},
            {key: v for key, v in list(experiment.trackables.items()) if key.startswith('sparsity-phi-@')},
        ]))

    def __dir__(self):
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def top_tokens_cardinalities(self):
        return sorted(int(_.split('-')[-1]) for _ in list(self.scores.keys()) if _.startswith('top-tokens'))

    @property
    def kernel_thresholds(self):
        """
        :return: list of strings eg ['0.60', '0.80', '0.25']
        """
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('topic-kernel'))

    @property
    def modalities_initials(self):
        return sorted(_.split('-')[-1][1] for _ in list(self.scores.keys()) if _.startswith('sparsity-phi'))

    @property
    def tracked_entity_names(self):
        return sorted(list(self.scores.keys()) + list(self._rest.keys()))

    @property
    def background_tokens_thresholds(self):
        return sorted(_.split('-')[-1] for _ in list(self.scores.keys()) if _.startswith('background-tokens-ratio'))

    @property
    def tau_trajectory_matrices_names(self):
        try:
            return self._rest['tau-trajectories'].matrices_names
        except KeyError:
            raise KeyError("Key 'tau-trajectories' was not found in scores [{}]. ALL: [{}]".format(
                ', '.join(sorted(self.scores.keys())), ', '.join(self.tracked_entity_names)))

    def _lookup(self, item):
        # Shared lookup for __getitem__ and __getattr__, which previously duplicated this logic verbatim.
        d = self.parser(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.scores:
            return self.scores[key]
        if key in self._rest:
            return self._rest[key]
        raise LookupError(
            "Requested item '{}', converted to '{}' after being parsed as {}. It was found neither in "
            "ValueTracker.scores [{}] nor in ValueTracker._rest [{}]".format(
                item, key, d, ', '.join(sorted(self.scores.keys())), ', '.join(sorted(self._rest.keys()))))

    def __getitem__(self, item):
        try:
            return self._lookup(item)
        except LookupError as e:
            raise KeyError(str(e))

    def __getattr__(self, item):
        try:
            return self._lookup(item)
        except LookupError as e:
            raise AttributeError(str(e))
```
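Given the parser-backed `__getitem__`/`__getattr__`, tracked scores can be addressed with several spellings of the same definition. A sketch, assuming a populated tracker with a `top-tokens-100` score and a 0.60 kernel threshold:

```python
# Hypothetical usage: 'tracker' is a ValueTracker, e.g. results.tracked.
tracker.perplexity.last            # last tracked perplexity value
tracker['sparsity_theta'].all      # underscores are normalized to hyphens before lookup
tracker.top100.average_coherence   # 'top100' resolves to 'top-tokens-100' via the normalizer
tracker.kernel60.average.purity    # 'kernel60' resolves to 'topic-kernel-0.60'
```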
```python
##############################################################


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEvolvingRegParams(object):
    """Holds regularizers_parameters data, which is a dictionary: keys should be the unique regularizer definitions
    (as in train.cfg; eg 'label-regularization-phi-dom-cls'). Keys should map to dictionaries, which map strings to
    lists. These dictionaries' keys should correspond to one of the supported dynamic parameters (eg 'tau', 'gamma')
    and each list should have length equal to the number of collection iterations, holding the evolution/trajectory
    of the parameter values.
    """
    _evolved = attr.ib(init=True, converter=lambda regularizers_params: {
        reg: {param: TrackedEntity(param, values_list) for param, values_list in reg_params.items()}
        for reg, reg_params in regularizers_params.items()}, cmp=True)
    regularizers_definitions = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self._evolved.keys()), takes_self=True), cmp=False, repr=True)

    def __iter__(self):
        return ((k, v) for k, v in self._evolved.items())

    def __getattr__(self, item):
        return self._evolved[item]


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedKernel(object):
    """data = [avg_coherence_list, avg_contrast_list, avg_purity_list, sizes_list, topic_name2elements_hash]"""
    average = attr.ib(init=True, cmp=True, converter=lambda x: KernelSubGroup(x[0], x[1], x[2], x[3]))
    _topics_data = attr.ib(converter=lambda topic_name2elements_hash: {
        key: KernelSubGroup(val['coherence'], val['contrast'], val['purity'], key)
        for key, val in list(topic_name2elements_hash.items())}, init=True, cmp=True)
    topics = attr.ib(default=attr.Factory(lambda self: sorted(self._topics_data.keys()), takes_self=True))

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTopTokens(object):
    average_coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('average_coherence', x),
                                cmp=True, repr=True)
    _topics_data = attr.ib(init=True, converter=lambda topic_name2coherence: {
        key: TrackedEntity(key, val) for key, val in list(topic_name2coherence.items())}, cmp=True, repr=True)
    topics = attr.ib(init=False, default=attr.Factory(lambda self: sorted(self._topics_data.keys()),
                                                      takes_self=True), cmp=False, repr=True)

    def __getattr__(self, item):
        if item in self._topics_data:
            return self._topics_data[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))


@attr.s(cmp=True, repr=True, str=True, slots=True)
class KernelSubGroup(object):
    """Contains averages over topics for metrics computed on lexical kernels defined with a specific threshold
    [coherence, contrast, purity], plus the average kernel size."""
    coherence = attr.ib(init=True, converter=lambda x: TrackedEntity('coherence', x), cmp=True, repr=True)
    contrast = attr.ib(init=True, converter=lambda x: TrackedEntity('contrast', x), cmp=True, repr=True)
    purity = attr.ib(init=True, converter=lambda x: TrackedEntity('purity', x), cmp=True, repr=True)
    size = attr.ib(init=True, converter=lambda x: TrackedEntity('size', x), cmp=True, repr=True)
    name = attr.ib(init=True, default='', cmp=True, repr=True)


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedTrajectories(object):
    _trajs = attr.ib(init=True, converter=lambda matrix_name2elements_hash: {
        k: TrackedEntity(k, v) for k, v in list(matrix_name2elements_hash.items())}, cmp=True, repr=True)
    trajectories = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(list(self._trajs.items()), key=lambda x: x[0]), takes_self=True), cmp=False, repr=False)
    matrices_names = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self._trajs.keys()), takes_self=True), cmp=False, repr=True)

    @property
    def phi(self):
        if 'phi' in self._trajs:
            return self._trajs['phi']
        raise AttributeError("No trajectory tracked for the 'phi' matrix")

    @property
    def theta(self):
        if 'theta' in self._trajs:
            return self._trajs['theta']
        raise AttributeError("No trajectory tracked for the 'theta' matrix")

    def __str__(self):
        return str(self.matrices_names)


@attr.s(cmp=True, repr=True, str=True, slots=True)
class TrackedEntity(object):
    name = attr.ib(init=True, converter=str, cmp=True, repr=True)
    all = attr.ib(init=True, converter=list, cmp=True, repr=True)  # list of the values tracked over training

    @property
    def last(self):
        return self.all[-1]

    def __len__(self):
        return len(self.all)

    def __getitem__(self, item):
        if item == 'all':
            return self.all
        try:
            return self.all[item]
        except IndexError:  # list indexing raises IndexError, not KeyError
            raise IndexError("self: {}, type of self.all: {}".format(self, type(self.all).__name__))
```
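A small sketch of the container's behavior (the values are made up for the example):

```python
entity = TrackedEntity('perplexity', [450.0, 320.5, 301.2])
entity.last       # 301.2 -- the value after the final collection pass
entity[0]         # 450.0 -- integer indexing into the tracked series
len(entity)       # 3
entity['all']     # [450.0, 320.5, 301.2] -- the full series
```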
```python
######### CONSTANTS during training and between subsequent fit calls #########

def _check_topics(self, attribute, value):
    if len(value) != len(set(value)):
        raise RuntimeError("Detected duplicates in input topics: [{}]".format(', '.join(str(x) for x in value)))


def _non_overlapping(self, attribute, value):
    if any(x in value for x in self.background_topics):
        raise RuntimeError("Detected overlapping domain topics [{}] with background topics [{}]".format(
            ', '.join(str(x) for x in value), ', '.join(str(x) for x in self.background_topics)))


def background_n_domain_topics_soundness(instance, attribute, value):
    if instance.nb_topics != len(instance.background_topics) + len(value):
        raise ValueError("nb_topics should be equal to len(background_topics) + len(domain_topics). "
                         "Instead, {} != {} + {}".format(instance.nb_topics, len(instance.background_topics), len(value)))


@attr.s(repr=True, cmp=True, str=True, slots=True)
class SteadyTrackedItems(object):
    """Supports only one tokens list for a specific btr (background-tokens-ratio) threshold."""
    dir = attr.ib(init=True, converter=str, cmp=True, repr=True)  # TODO: rename to 'dataset_dir'
    model_label = attr.ib(init=True, converter=str, cmp=True, repr=True)
    dataset_iterations = attr.ib(init=True, converter=int, cmp=True, repr=True)
    nb_topics = attr.ib(init=True, converter=int, cmp=True, repr=True)
    document_passes = attr.ib(init=True, converter=int, cmp=True, repr=True)
    background_topics = attr.ib(init=True, converter=list, cmp=True, repr=True, validator=_check_topics)
    domain_topics = attr.ib(init=True, converter=list, cmp=True, repr=True,
                            validator=[_check_topics, _non_overlapping, background_n_domain_topics_soundness])
    background_tokens_threshold = attr.ib(init=True, converter=float, cmp=True, repr=True)
    modalities = attr.ib(init=True, converter=dict, cmp=True, repr=True)

    parser = StringToDictParser()

    def __dir__(self):
        return ['dir', 'model_label', 'dataset_iterations', 'nb_topics', 'document_passes',
                'background_topics', 'domain_topics', 'background_tokens_threshold', 'modalities']

    @classmethod
    def from_dict(cls, data):
        steady = data['scalars']
        background_tokens_threshold = 0  # 0 means no distinction between background and "domain" tokens
        for eval_def in data['tracked']:
            if eval_def.startswith('background-tokens-ratio'):
                background_tokens_threshold = max(background_tokens_threshold, float(eval_def.split('-')[-1]))
        return SteadyTrackedItems(steady['dir'], steady['label'], steady['dataset_iterations'],
                                  steady['nb_topics'], steady['document_passes'], steady['background_topics'],
                                  steady['domain_topics'], background_tokens_threshold, steady['modalities'])

    @classmethod
    def from_experiment(cls, experiment):
        """
        Uses the maximum threshold found amongst the defined 'background-tokens-ratio' scores.\n
        :param patm.modeling.experiment.Experiment experiment:
        :return:
        """
        ds = [d for d in experiment.topic_model.evaluator_definitions if d.startswith('background-tokens-ratio-')]
        m = 0
        for d in (_.split('-')[-1] for _ in ds):
            if str(d).startswith('0.'):
                m = max(m, float(d))
            else:
                m = max(m, float('0.' + str(d)))

        return SteadyTrackedItems(experiment.current_root_dir, experiment.topic_model.label,
                                  experiment.dataset_iterations, experiment.topic_model.nb_topics,
                                  experiment.topic_model.document_passes, experiment.topic_model.background_topics,
                                  experiment.topic_model.domain_topics, m, experiment.topic_model.modalities_dictionary)


@attr.s(repr=True, str=True, cmp=True, slots=True)
class TokensList(object):
    tokens = attr.ib(init=True, converter=list, repr=True, cmp=True)

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, item):
        return self.tokens[item]

    def __iter__(self):
        return iter(self.tokens)

    def __contains__(self, item):
        return item in self.tokens


def kernel_def2_kernel(kernel_def):
    return 'kernel' + kernel_def.split('.')[-1]


def top_tokens_def2_top(top_def):
    return 'top' + top_def.split('-')[-1]


@attr.s(repr=True, str=True, cmp=True, slots=True)
class FinalStateEntities(object):
    kernel_hash = attr.ib(init=True, converter=lambda x: {
        kernel_def: TopicsTokens(data) for kernel_def, data in x.items()}, cmp=True, repr=False)
    top_hash = attr.ib(init=True, converter=lambda x: {
        top_tokens_def: TopicsTokens(data) for top_tokens_def, data in x.items()}, cmp=True, repr=False)
    _bg_tokens = attr.ib(init=True, converter=TokensList, cmp=True, repr=False)
    kernel_defs = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self.kernel_hash.keys()), takes_self=True))
    kernels = attr.ib(init=False, default=attr.Factory(
        lambda self: [kernel_def2_kernel(x) for x in self.kernel_defs], takes_self=True))
    top_defs = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self.top_hash.keys()), takes_self=True))
    top = attr.ib(init=False, default=attr.Factory(
        lambda self: [top_tokens_def2_top(x) for x in self.top_defs], takes_self=True))
    background_tokens = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self._bg_tokens.tokens), takes_self=True))

    parse = attr.ib(init=False, factory=StringToDictParser, cmp=False, repr=False)

    def __getattr__(self, item):
        d = self.parse(item, encode=True, debug=False)
        key = '-'.join(d.values())
        if key in self.kernel_hash:
            return self.kernel_hash[key]
        if key in self.top_hash:
            return self.top_hash[key]
        # __getattr__ must raise AttributeError; raising KeyError here breaks hasattr and getattr defaults.
        raise AttributeError(
            "Requested item '{}', converted to '{}' after being parsed as {}. It was found neither in "
            "kernel thresholds [{}] nor in top-tokens cardinalities [{}]".format(
                item, key, d, ', '.join(sorted(self.kernel_hash.keys())), ', '.join(sorted(self.top_hash.keys()))))

    def __str__(self):
        return 'bg-tokens: {}\n'.format(len(self._bg_tokens)) + '\n'.join(
            ['final state: {}\n{}'.format(y, getattr(self, y)) for y in self.top + self.kernels])

    @classmethod
    def from_dict(cls, data):
        return FinalStateEntities(
            {'topic-kernel-' + threshold: tokens_hash
             for threshold, tokens_hash in list(data['final']['topic-kernel'].items())},
            {'top-tokens-' + nb_tokens: tokens_hash
             for nb_tokens, tokens_hash in list(data['final']['top-tokens'].items())},
            data['final']['background-tokens']
        )

    @classmethod
    def from_experiment(cls, experiment):
        return FinalStateEntities(experiment.final_tokens['topic-kernel'],
                                  experiment.final_tokens['top-tokens'],
                                  experiment.final_tokens['background-tokens'])


@attr.s(repr=True, cmp=True, slots=True)
class TopicsTokens(object):
    _tokens = attr.ib(init=True, converter=lambda topic_name2tokens: {
        topic_name: TokensList(tokens_list) for topic_name, tokens_list in topic_name2tokens.items()},
        cmp=True, repr=False)
    _nb_columns = attr.ib(init=False, default=10, cmp=False, repr=False)
    _nb_rows = attr.ib(init=False, default=attr.Factory(
        lambda self: int(ceil(float(len(self._tokens)) / self._nb_columns)), takes_self=True),
        repr=False, cmp=False)
    topics = attr.ib(init=False, default=attr.Factory(
        lambda self: sorted(self._tokens.keys()), takes_self=True), repr=True, cmp=False)
    __lens = attr.ib(init=False, default=attr.Factory(
        lambda self: {k: {'string': len(k), 'list': len(str(len(v)))} for k, v in self._tokens.items()},
        takes_self=True), cmp=False, repr=False)

    def __iter__(self):
        return ((k, v) for k, v in self._tokens.items())

    def __getattr__(self, item):
        if item in self._tokens:
            return self._tokens[item]
        raise AttributeError("Topic '{}' is not registered as tracked".format(item))

    def __str__(self):
        return '\n'.join([self._get_row_string(x) for x in self._gen_rows()])

    def _gen_rows(self):
        i, j = 0, 0
        while i < len(self._tokens):
            row = self._index(j)
            j += 1
            i += len(row)
            yield row

    def _index(self, row_nb):
        return [_f for _f in [self.topics[x] if x < len(self._tokens) else None
                              for x in range(self._nb_columns * row_nb,
                                             self._nb_columns * row_nb + self._nb_columns)] if _f]

    def _get_row_string(self, row_topic_names):
        return ' '.join(['{}{}: {}{}'.format(
            x[1],
            (self._max(x[0], 'string') - len(x[1])) * ' ',
            (self._max(x[0], 'list') - self.__lens[x[1]]['list']) * ' ',
            len(self._tokens[x[1]])) for x in enumerate(row_topic_names)])

    def _max(self, column_index, entity):
        return max([self.__lens[self.topics[x]][entity] for x in self._range(column_index)])

    def _range(self, column_index):
        return list(range(column_index, self._get_last_index_in_column(column_index), self._nb_columns))

    def _get_last_index_in_column(self, column_index):
        return self._nb_rows * self._nb_columns - self._nb_columns + column_index
```
```python
class RoundTripEncoder(json.JSONEncoder):
    """Flattens the attrs-based result objects above into the nested dicts/lists that from_dict reads back."""

    def default(self, obj):
        if isinstance(obj, TokensList):
            return obj.tokens
        if isinstance(obj, TopicsTokens):
            return {topic_name: list(getattr(obj, topic_name)) for topic_name in obj.topics}
        if isinstance(obj, FinalStateEntities):
            return {'topic-kernel': {kernel_def.split('-')[-1]: getattr(obj, kernel_def2_kernel(kernel_def))
                                     for kernel_def in obj.kernel_defs},
                    'top-tokens': {top_def.split('-')[-1]: getattr(obj, top_tokens_def2_top(top_def))
                                   for top_def in obj.top_defs},
                    'background-tokens': obj.background_tokens}
        if isinstance(obj, SteadyTrackedItems):
            return {'dir': obj.dir,
                    'label': obj.model_label,
                    'dataset_iterations': obj.dataset_iterations,
                    'nb_topics': obj.nb_topics,
                    'document_passes': obj.document_passes,
                    'background_topics': obj.background_topics,
                    'domain_topics': obj.domain_topics,
                    'modalities': obj.modalities}
        if isinstance(obj, TrackedEntity):
            return obj.all
        if isinstance(obj, TrackedTrajectories):
            return {matrix_name: tau_elements for matrix_name, tau_elements in obj.trajectories}
        if isinstance(obj, TrackedEvolvingRegParams):
            return dict(obj)
        if isinstance(obj, TrackedTopTokens):
            return {'avg_coh': obj.average_coherence,
                    'topics': {topic_name: getattr(obj, topic_name).all for topic_name in obj.topics}}
        if isinstance(obj, TrackedKernel):
            return {'avg_coh': obj.average.coherence,
                    'avg_con': obj.average.contrast,
                    'avg_pur': obj.average.purity,
                    'size': obj.average.size,
                    'topics': {topic_name: {'coherence': getattr(obj, topic_name).coherence.all,
                                            'contrast': getattr(obj, topic_name).contrast.all,
                                            'purity': getattr(obj, topic_name).purity.all}
                               for topic_name in obj.topics}}
        if isinstance(obj, ValueTracker):
            # Keep the flat score entries (perplexity, sparsity-*, background-tokens-ratio-*); the
            # individual 'top-tokens-*' and 'topic-kernel-*' entries are regrouped under nested dicts below.
            serialized = {name: tracked_entity for name, tracked_entity in list(obj.scores.items())
                          if not any(name.startswith(x) for x in ['top-tokens', 'topic-kernel'])}
            serialized['top-tokens'] = {k: obj.scores['top-tokens-' + str(k)]
                                        for k in obj.top_tokens_cardinalities}
            serialized['topic-kernel'] = {k: obj.scores['topic-kernel-' + str(k)]
                                          for k in obj.kernel_thresholds}
            serialized['tau-trajectories'] = {k: getattr(obj.tau_trajectories, k)
                                              for k in obj.tau_trajectory_matrices_names}
            serialized['collection-passes'] = obj.collection_passes
            serialized['regularization-dynamic-parameters'] = obj.regularization_dynamic_parameters
            return serialized
        if isinstance(obj, ExperimentalResults):
            return {'scalars': obj.scalars,
                    'tracked': obj.tracked,
                    'final': obj.final,
                    'regularizers': obj.regularizers,
                    'reg_defs': obj.reg_defs,
                    'score_defs': obj.score_defs}
        return super(RoundTripEncoder, self).default(obj)


class RoundTripDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, obj):
        # Currently a pass-through; reconstruction into typed objects happens in ExperimentalResults.from_dict.
        return obj
```
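A sketch of the round-trip contract the two classes implement together (the `results` variable is illustrative):

```python
# Serialize: RoundTripEncoder flattens the attrs-based objects into plain dicts/lists.
serialized = json.dumps(results, cls=RoundTripEncoder, indent=2)

# Deserialize: RoundTripDecoder yields plain dicts; ExperimentalResults.from_dict then
# rebuilds the typed objects (SteadyTrackedItems, ValueTracker, FinalStateEntities).
rebuilt = ExperimentalResults.from_dict(json.loads(serialized, cls=RoundTripDecoder))
```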