patm.modeling.experiment (rating: F)

Complexity

Total Complexity 69

Size/Duplication

Total Lines 343
Duplicated Lines 0%

Importance

Changes 0
Metric                              Value
eloc (executable lines of code)     220
dl (duplicated lines)               0
loc (total lines of code)           343
rs                                  2.88
c                                   0
b                                   0
f                                   0
wmc (weighted method complexity)    69

27 Methods

Rating   Name   Duplication   Size   Complexity  
A Experiment.model_factory() 0 3 1
A Experiment.save_experiment() 0 12 3
A Experiment.load_experiment() 0 32 2
A DegenerationChecker._has_degenerated() 0 3 4
A Experiment.__init__() 0 20 1
A DegenerationChecker._initialize() 0 6 1
A DegenerationChecker._build_tuple() 0 20 4
A DegenerationChecker.__init__() 0 2 1
A Experiment.topic_model() 0 6 1
A Experiment.dataset_iterations() 0 3 1
A DegenerationChecker.keys() 0 3 1
B Experiment.init_empty_trackables() 0 17 6
A DegenerationChecker.get_degenerated_tuples() 0 9 1
A Experiment._get_final_tokens() 0 2 1
A DegenerationChecker._add_final() 0 3 2
A Experiment._assert_max_decimals() 0 18 3
A DegenerationChecker._get_struct() 0 2 1
D Experiment.update() 0 43 12
A Experiment._strip_parameters() 0 10 4
A DegenerationChecker.get_degeneration_info() 0 3 1
A DegenerationChecker._get_degen_keys() 0 2 2
A DegenerationChecker.__str__() 0 2 1
A Experiment.current_root_dir() 0 3 1
A DegenerationChecker.__repr__() 0 2 1
A DegenerationChecker._build_degen_info() 0 13 5
A Experiment.dictionary() 0 4 1
B Experiment._get_trackables() 0 25 5

How to fix

Complexity

Complex classes like patm.modeling.experiment often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
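For illustration, here is a minimal sketch of what Extract Class could look like in this module, assuming the per-metric history that Experiment manages inline through self.trackables is the cohesive component to pull out. MetricsTracker and its method names are hypothetical, not part of patm:

class MetricsTracker:
    """Owns the per-pass metric history that Experiment currently manipulates in place."""

    def __init__(self):
        self._series = {}  # metric definition -> list of observed values, one per training pass

    def register(self, definition):
        # create an empty history for a metric to be tracked during training
        self._series[definition] = []

    def record(self, definition, values):
        # append the observations produced by one training chunk
        self._series[definition].extend(values)

    def history(self, definition):
        return list(self._series[definition])

Experiment would then delegate to the tracker (register on initialization, record inside update()) instead of manipulating nested dicts and lists in place, which would remove much of the branching behind update()'s complexity of 12.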

from .model_factory import ModelFactory
from .persistence import ResultsWL, ModelWL
from .regularization.regularizers_factory import REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH as DYN_COEFS


class Experiment:
    """
    This class encapsulates experiment-related activities, such as tracking various quantities during model training,
    and can persist these tracked quantities as experimental results.\n
    Experimental results include:
    - number of topics: the number of topics inferred
    - document passes: the inner iterations over each document. The Phi matrix gets updated 'document_passes' x 'nb_docs' times during one pass through the whole collection
    - root dir: the base directory of the collection on which experiments were conducted
    - model label: the unique identifier of the model used for the experiments
    - trackable "evaluation" metrics
    - regularization parameters
    - _evaluator_name2definition
    """

    MAX_DECIMALS = 2

    def __init__(self, dataset_dir):
        """
        Encapsulates experimentation by doing topic modeling on a dataset/'collection' found in the given dataset_dir. A 'collection' is a document collection processed into BoW format and possibly split into 'train' and 'test' splits.\n
        :param str dataset_dir: the full path to a dataset/collection-specific directory
        """
        self._dir = dataset_dir
        self._loaded_dictionary = None  # artm.Dictionary object. Data population happens upon artm.ARTM object creation in model_factory; dictionary.load(bin_dict_path) is called there
        self._topic_model = None
        self.collection_passes = []
        self.trackables = None
        self.train_results_handler = ResultsWL(self)
        self.phi_matrix_handler = ModelWL(self)
        self.failed_top_tokens_coherence = {}
        self._total_passes = 0
        self.regularizers_dynamic_parameters = None
        self._last_tokens = {}

    def init_empty_trackables(self, model):
        self._topic_model = model
        self.trackables = {}
        self.failed_top_tokens_coherence = {}
        self._total_passes = 0
        for evaluator_definition in model.evaluator_definitions:
            if self._strip_parameters(evaluator_definition) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                self.trackables[Experiment._assert_max_decimals(evaluator_definition)] = []
            elif evaluator_definition.startswith('topic-kernel-'):
                self.trackables[Experiment._assert_max_decimals(evaluator_definition)] = [[], [], [], [], {t_name: {'coherence': [], 'contrast': [], 'purity': [], 'size': []} for t_name in model.domain_topics}]
            elif evaluator_definition.startswith('top-tokens-'):
                self.trackables[evaluator_definition] = [[], {t_name: [] for t_name in model.domain_topics}]
                self.failed_top_tokens_coherence[evaluator_definition] = {t_name: [] for t_name in model.domain_topics}
        self.collection_passes = []
        if not all(x in DYN_COEFS for _, x in model.long_types_n_types):
            raise KeyError("One of the regularizer types [{}] (not the unique types!) is missing from REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH in the regularizers_factory module".format(', '.join("'{}'".format(x) for _, x in model.long_types_n_types)))
        self.regularizers_dynamic_parameters = {unique_type: {attr: [] for attr in DYN_COEFS[reg_type]} for unique_type, reg_type in model.long_types_n_types}

    @property
    def dataset_iterations(self):
        return sum(self.collection_passes)

    @property
    def model_factory(self):
        return ModelFactory(self.dictionary)

    @property
    def topic_model(self):
        """
        :rtype: patm.modeling.topic_model.TopicModel
        """
        return self._topic_model

    @staticmethod
    def _strip_parameters(score_definition):
        tokens = []
        for el in score_definition.split('-'):
            try:
                _ = float(el)
            except ValueError:
                if el[0] != '@':
                    tokens.append(el)
        return '-'.join(tokens)

    @classmethod
    def _assert_max_decimals(cls, definition):
        """Converts definitions like:\n
        - 'topic-kernel-0.6'            -> 'topic-kernel-0.60'\n
        - 'topic-kernel-0.871'          -> 'topic-kernel-0.87'\n
        - 'background-tokens-ratio-0.3' -> 'background-tokens-ratio-0.30'
        """
        s = definition.split('-')
        sl = s[-1]
        try:
            _ = float(sl)
            if len(sl) < 4:
                sl = '{}{}'.format(sl, '0' * (2 + cls.MAX_DECIMALS - len(sl)))
            else:
                sl = sl[:4]
            return '-'.join(s[:-1] + [sl])
        except ValueError:
            return definition

    @property
    def dictionary(self):
        """This dictionary is passed to the Perplexity artm scorer"""
        return self._loaded_dictionary

    @dictionary.setter
    def dictionary(self, artm_dictionary):
        self._loaded_dictionary = artm_dictionary

    # TODO refactor this; remove dubious exceptions
    def update(self, topic_model, span):
        self.collection_passes.append(span)  # iterations performed on the train set for the current 'steady' chunk

        for unique_reg_type, reg_settings in list(topic_model.get_regs_param_dict().items()):
            for param_name, param_value in list(reg_settings.items()):
                self.regularizers_dynamic_parameters[unique_reg_type][param_name].extend([param_value] * span)
        for evaluator_name, evaluator_definition in zip(topic_model.evaluator_names, topic_model.evaluator_definitions):
            reportable_to_results = topic_model.get_evaluator(evaluator_name).evaluate(topic_model.artm_model)
            definition_with_max_decimals = Experiment._assert_max_decimals(evaluator_definition)

            if self._strip_parameters(evaluator_definition) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                self.trackables[definition_with_max_decimals].extend(reportable_to_results['value'][-span:])
            elif evaluator_definition.startswith('topic-kernel-'):
                self.trackables[definition_with_max_decimals][0].extend(reportable_to_results['average_coherence'][-span:])
                self.trackables[definition_with_max_decimals][1].extend(reportable_to_results['average_contrast'][-span:])
                self.trackables[definition_with_max_decimals][2].extend(reportable_to_results['average_purity'][-span:])
                self.trackables[definition_with_max_decimals][3].extend(reportable_to_results['average_size'][-span:])
                for topic_name, topic_metrics in list(self.trackables[definition_with_max_decimals][4].items()):
                    topic_metrics['coherence'].extend([x[topic_name] for x in reportable_to_results['coherence'][-span:]])
                    topic_metrics['contrast'].extend([x[topic_name] for x in reportable_to_results['contrast'][-span:]])
                    topic_metrics['purity'].extend([x[topic_name] for x in reportable_to_results['purity'][-span:]])
                    topic_metrics['size'].extend([x[topic_name] for x in reportable_to_results['size'][-span:]])
            elif evaluator_definition.startswith('top-tokens-'):
                self.trackables[evaluator_definition][0].extend(reportable_to_results['average_coherence'][-span:])

                assert 'coherence' in reportable_to_results
                for topic_name, topic_metrics in list(self.trackables[evaluator_definition][1].items()):
                    try:
                        topic_metrics.extend([x[topic_name] for x in reportable_to_results['coherence'][-span:]])
                    except KeyError:
                        # bookkeeping for stretches of training passes during which coherence for this topic was missing
                        if len(self.failed_top_tokens_coherence[evaluator_definition][topic_name]) == 0:
                            self.failed_top_tokens_coherence[evaluator_definition][topic_name].append((self._total_passes, span))
                        else:
                            if span + self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1][0] == self._total_passes:
                                prev_tuple = self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1]
                                self.failed_top_tokens_coherence[evaluator_definition][topic_name][-1] = (self._total_passes, prev_tuple[1] + span)
                self._total_passes += span
        self.final_tokens = {
            'topic-kernel': {eval_def: self._get_final_tokens(eval_def) for eval_def in
                             self.topic_model.evaluator_definitions if eval_def.startswith('topic-kernel-')},
            'top-tokens': {eval_def: self._get_final_tokens(eval_def) for eval_def in
                           self.topic_model.evaluator_definitions if eval_def.startswith('top-tokens-')},
            'background-tokens': self.topic_model.background_tokens
        }

    @property
    def current_root_dir(self):
        return self._dir

    def save_experiment(self, save_phi=True):
        """Dumps the accumulated experimental results (a dictionary) under the current model's label. The file is saved in the directory specified by the latest train specifications (TrainSpecs)."""
        if not self.collection_passes:
            raise DidNotReceiveTrainSignalException("Model probably hasn't been fitted, since len(self.collection_passes) = {}".format(len(self.collection_passes)))
        # The assert below checks that the number of observations recorded per tracked metric equals the total number of collection passes (training iterations over the document dataset).
        # artm.ARTM.scores satisfy this.
        # assert all(map(lambda x: len(x) == sum(self.collection_passes), [values_list for eval2scoresdict in self.trackables.values() for values_list in eval2scoresdict.values()]))
        # It will fail once metrics outside the artm library are tracked, because those can only capture the last state of the metric trajectory due to fitting by "chunks".
        self.train_results_handler.save(self._topic_model.label)
        if save_phi:
            self.phi_matrix_handler.save(self._topic_model.label)

    def load_experiment(self, model_label):
        """
        Given a unique model label, restores the state of the experiment from disk. Loads all tracked values of the experimental results
        and the state of the TopicModel inferred so far, namely the phi p_wt matrix.
        In detail, the following settings are loaded:\n
        - document collection fit iteration steady_chunks\n
        - evaluation metrics/measures tracked per iteration\n
        - regularization parameters\n
        - document passes\n
        :param str model_label: a unique identifier of a topic model
        :return: the restored topic model
        :rtype: patm.modeling.topic_model.TopicModel
        """
        results = self.train_results_handler.load(model_label)
        self._topic_model = self.phi_matrix_handler.load(model_label, results)
        self.failed_top_tokens_coherence = {}
        self.collection_passes = [item for sublist in results.tracked.collection_passes for item in sublist]
        self._total_passes = sum(sum(_) for _ in results.tracked.collection_passes)
        try:
            self.regularizers_dynamic_parameters = dict(results.tracked.regularization_dynamic_parameters)
        except KeyError as e:
            print(e)
            raise RuntimeError("Tracked: {}, dynamic reg params: {}".format(results.tracked, results.tracked.regularization_dynamic_parameters))
        self.trackables = self._get_trackables(results)
        self.final_tokens = {
            'topic-kernel': {kernel_def: {t_name: list(tokens) for t_name, tokens in topics_tokens} for kernel_def, topics_tokens in list(results.final.kernel_hash.items())},
            'top-tokens': {top_tokens_def: {t_name: list(tokens) for t_name, tokens in topics_tokens} for top_tokens_def, topics_tokens in list(results.final.top_hash.items())},
            'background-tokens': list(results.final.background_tokens),
        }
        return self._topic_model

    def _get_final_tokens(self, evaluation_definition):
        return self.topic_model.artm_model.score_tracker[self.topic_model.definition2evaluator_name[evaluation_definition]].tokens[-1]

    def _get_trackables(self, results):
        """
        :param results.experimental_results.ExperimentalResults results:
        :return: a mapping of evaluator definitions to their tracked value histories
        :rtype: dict
        """
        trackables = {}
        for k in (_.replace('_', '-') for _ in dir(results.tracked) if _ not in ('tau_trajectories', 'regularization_dynamic_parameters')):
            if self._strip_parameters(k) in ('perplexity', 'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio'):
                trackables[Experiment._assert_max_decimals(k)] = results.tracked[k].all
            elif k.startswith('topic-kernel-'):
                trackables[Experiment._assert_max_decimals(k)] = [
                    results.tracked[k].average.coherence.all,
                    results.tracked[k].average.contrast.all,
                    results.tracked[k].average.purity.all,
                    results.tracked[k].average.size.all,
                    {t_name: {'coherence': getattr(results.tracked[k], t_name).coherence.all,
                              'contrast': getattr(results.tracked[k], t_name).contrast.all,
                              'purity': getattr(results.tracked[k], t_name).purity.all,
                              'size': getattr(results.tracked[k], t_name).size.all} for t_name in results.scalars.domain_topics}
                ]
            elif k.startswith('top-tokens-'):
                trackables[k] = [results.tracked[k].average_coherence.all,
                                 {t_name: getattr(results.tracked[k], t_name).all for t_name in results.scalars.domain_topics}]
        return trackables


class DegenerationChecker(object):
    def __init__(self, reference_keys):
        self._keys = sorted(reference_keys)

    def _initialize(self):
        self._iter = 0
        self._degen_info = {key: [] for key in self._keys}
        self._building = {k: False for k in self._keys}
        self._li = {k: 0 for k in self._keys}
        self._ri = {k: 0 for k in self._keys}

    def __str__(self):
        return '[{}]'.format(', '.join(self._keys))

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, str(self))

    @property
    def keys(self):
        return self._keys  # was 'return self.keys', which recurses infinitely

    @keys.setter
    def keys(self, keys):
        self._keys = sorted(keys)

    def get_degeneration_info(self, dict_list):
        self._initialize()
        struct = self._get_struct(dict_list)
        return {x: self.get_degenerated_tuples(x, struct) for x in self._keys}

    def get_degenerated_tuples(self, key, struct):
        """
        :param str key:
        :param list struct: output of self._get_struct(dict_list)
        :return: a list of (start, end) tuples of train iterations during which the information had degenerated (gone missing)
        """
        tuples = [t for t in (self._build_tuple(key, x[1], x[0]) for x in enumerate(struct)) if t]
        self._add_final(key, tuples)
        return tuples

    def _add_final(self, key, info):
        if self._building[key]:
            info.append((self._li[key], self._ri[key]))

    def _build_tuple(self, key, struct_element, iter_count):
        """
        :param str key:
        :param list struct_element:
        :param int iter_count:
        :return:
        """
        r = None
        if key in struct_element:  # the key has lost its information (has degenerated) for iteration/collection_pass iter_count
            if self._building[key]:
                self._ri[key] += 1
            else:
                self._li[key] = iter_count
                self._ri[key] = iter_count + 1
                self._building[key] = True
        else:
            if self._building[key]:
                r = (self._li[key], self._ri[key])
                self._building[key] = False
        return r

    def _build_degen_info(self, key, struct):
        for i, el in enumerate(struct):
            if key in el:  # the key has lost its information (has degenerated) for iteration/collection_pass i
                if self._building[key]:
                    self._ri[key] += 1
                else:
                    self._li[key] = i
                    self._ri[key] = i + 1
                    self._building[key] = True
            else:
                if self._building[key]:
                    self._degen_info[key].append((self._li[key], self._ri[key]))
                    self._building[key] = False

    def _get_struct(self, dict_list):
        return [self._get_degen_keys(self._keys, x) for x in dict_list]

    def _get_degen_keys(self, key_list, a_dict):
        return [k for k in key_list if self._has_degenerated(k, a_dict)]

    def _has_degenerated(self, key, a_dict):
        # a missing, empty, or otherwise falsy value counts as degenerated (the former len(...) == 0 check was redundant)
        return key not in a_dict or not a_dict[key]


class EvaluationOutputLoadingException(Exception): pass


class DidNotReceiveTrainSignalException(Exception): pass
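
For context, here is an illustrative sketch (not part of the analyzed file) of how DegenerationChecker is meant to be driven, assuming each dict in dict_list holds the metric values observed during one training pass; the metric names and values below are made up:

checker = DegenerationChecker(['coherence', 'purity'])
per_pass_metrics = [
    {'coherence': [0.5], 'purity': [0.7]},  # pass 0: both metrics present
    {'coherence': [],    'purity': [0.6]},  # pass 1: 'coherence' has degenerated
    {'coherence': [],    'purity': [0.5]},  # pass 2: still degenerated
    {'coherence': [0.4], 'purity': [0.4]},  # pass 3: recovered
]
info = checker.get_degeneration_info(per_pass_metrics)
# info == {'coherence': [(1, 3)], 'purity': []}

Each (start, end) pair is a half-open interval of training passes during which the metric's information was missing.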