TrainSpecs.to_taus_slice()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 2
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import pprint
2
3
import warnings
4
from collections import Counter
5
6
from .regularization.regularizers_factory import REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH
7
8
9
class TopicModel(object):
    """
    An instance of this class encapsulates the behaviour of a topic_model.
    - LIMITATION:
        Does not support dynamic changing of the number of topics between different training cycles.
    """
    def __init__(self, label, artm_model, evaluators):
        """
        Creates a topic model object. The model is ready to add regularizers to it.\n
        :param str label: a unique identifier for the model; model name
        :param artm.ARTM artm_model: a reference to an artm model object
        :param dict evaluators: mapping from evaluator-definitions strings to patm.evaluation.base_evaluator.ArtmEvaluator objects
         ; i.e. {'perplexity': obj0, 'sparsity-phi-\@dc': obj1, 'sparsity-phi-\@ic': obj2, 'topic-kernel-0.6': obj3,
         'topic-kernel-0.8': obj4, 'top-tokens-10': obj5}
        """
        self.label = label
        self.artm_model = artm_model
        self._definition2evaluator = evaluators
        # e.g. {'sparsity-phi': 'spd_p', 'top-tokens-10': 'tt10'}
        self.definition2evaluator_name = {k: v.name for k, v in evaluators.items()}
        # inverse mapping, e.g. {'sp_p': 'sparsity-phi', 'tt10': 'top-tokens-10'}
        self._evaluator_name2definition = {v.name: k for k, v in evaluators.items()}
        self._reg_longtype2name = {}  # ie {'smooth-theta': 'smb_t', 'sparse-phi': 'spd_p'}
        self._reg_name2wrapper = {}

    def add_regularizer_wrapper(self, reg_wrapper):
        """Register a regularizer wrapper and attach its underlying artm regularizer to the model."""
        self._reg_name2wrapper[reg_wrapper.name] = reg_wrapper
        self._reg_longtype2name[reg_wrapper.long_type] = reg_wrapper.name
        self.artm_model.regularizers.add(reg_wrapper.artm_regularizer)

    def add_regularizer_wrappers(self, reg_wrappers):
        """Register every wrapper in the given iterable."""
        for wrapper in reg_wrappers:
            self.add_regularizer_wrapper(wrapper)

    @property
    def pformat_regularizers(self):
        """Pretty-formatted report of each regularizer's static parameters, targeted topics and modalities."""
        def _static_params(reg_name):
            # BUGFIX: work on a copy. The previous implementation deleted
            # 'topic_names' from the wrapper's own static_parameters dict,
            # permanently mutating wrapper state as a side effect of printing.
            d = dict(self._reg_name2wrapper[reg_name].static_parameters)
            d.pop('topic_names', None)
            return d

        def _format_targets(topic_names):
            # An empty topic_names list means the regularizer targets all topics.
            return 'all' if len(topic_names) == 0 else '[{}]'.format(', '.join(topic_names))

        return pprint.pformat({
            reg_long_type: dict(_static_params(reg_name),
                                **{'target topics': _format_targets(self.get_reg_obj(reg_name).topic_names),
                                   'mods': getattr(self.get_reg_obj(reg_name), 'class_ids', None)})
            for reg_long_type, reg_name in self._reg_longtype2name.items()
        })

    @property
    def pformat_modalities(self):
        """Pretty-formatted string of the modality-to-weight mapping."""
        return pprint.pformat(self.modalities_dictionary)

    def initialize_regularizers(self, collection_passes, document_passes):
        """
        Call this method before starting of training to build some regularizers settings from run-time parameters.\n
        :param int collection_passes:
        :param int document_passes:
        """
        self._trajs = {}
        for reg_name, wrapper in self._reg_name2wrapper.items():
            self._trajs[reg_name] = wrapper.get_tau_trajectory(collection_passes)
            wrapper.set_alpha_iters_trajectory(document_passes)

    def get_reg_wrapper(self, reg_name):
        """Return the wrapper registered under reg_name, or None if absent."""
        return self._reg_name2wrapper.get(reg_name, None)

    @property
    def regularizer_names(self):
        """Alphabetically sorted names of the regularizers attached to the artm model."""
        return sorted(_ for _ in self.artm_model.regularizers.data)

    @property
    def regularizer_wrappers(self):
        """Wrapper objects in the same order as regularizer_names."""
        return list(map(lambda x: self._reg_name2wrapper[x], self.regularizer_names))

    @property
    def regularizer_types(self):
        """Types of the attached regularizers, ordered by regularizer name (lazy map object)."""
        return map(lambda x: self._reg_name2wrapper[x].type, self.regularizer_names)

    @property
    def regularizer_unique_types(self):
        """Long types of the attached regularizers, ordered by regularizer name (lazy map object)."""
        return map(lambda x: self._reg_name2wrapper[x].long_type, self.regularizer_names)

    @property
    def long_types_n_types(self):
        """(long_type, type) pairs of the attached regularizers, ordered by name (lazy map object)."""
        return map(lambda x: (self._reg_name2wrapper[x].long_type, self._reg_name2wrapper[x].type), self.regularizer_names)

    @property
    def evaluator_names(self):
        """Alphabetically sorted names of the scores attached to the artm model."""
        return sorted(_ for _ in self.artm_model.scores.data)

    @property
    def evaluator_definitions(self):
        """Evaluator definition strings in the order given by evaluator_names."""
        return [self._evaluator_name2definition[eval_name] for eval_name in self.evaluator_names]

    @property
    def evaluators(self):
        """ArtmEvaluator objects in the order given by evaluator_definitions."""
        return [self._definition2evaluator[ev_def] for ev_def in self.evaluator_definitions]

    @property
    def tau_trajectories(self):
        """Iterator over (reg_name, trajectory) pairs that have a non-None trajectory.
        Valid only after initialize_regularizers has been called."""
        return filter(lambda x: x[1] is not None, self._trajs.items())

    @property
    def nb_topics(self):
        return self.artm_model.num_topics

    @property
    def topic_names(self):
        return self.artm_model.topic_names

    @property
    def domain_topics(self):
        """Returns the mostly agreed list of topic names found in all evaluators"""
        c = Counter()
        for definition, evaluator in self._definition2evaluator.items():
            tn = getattr(evaluator, 'topic_names', None)
            if tn:
                # Count each distinct subset of topic names targeted by an evaluator.
                c['+'.join(tn)] += 1
        # NOTE(review): this warns only for MORE THAN TWO distinct subsets; if the
        # intent is "more than one distinct subset", it should be len(c) > 1 -- confirm.
        if len(c) > 2:
            warnings.warn("There exist {} different subsets of all the topic names targeted by evaluators".format(len(c)))
        try:
            return c.most_common(1)[0][0].split('+')
        except IndexError:
            raise IndexError("Failed to compute domain topic names. Please enable at least one score with topics=domain_names argument. Scores: {}".
                             format(['{}: {}'.format(x.name, x.settings) for x in self.evaluators]))

    @property
    def background_topics(self):
        """Topic names of the model that are not in domain_topics."""
        return [topic_name for topic_name in self.artm_model.topic_names if topic_name not in self.domain_topics]

    @property
    def background_tokens(self):
        """Tokens from the last iteration of the tracked 'background-tokens-ratio-*' score,
        or None when no such evaluator is registered."""
        background_tokens_eval_name = ''
        for eval_def, eval_name in self.definition2evaluator_name.items():
            if eval_def.startswith('background-tokens-ratio-'):
                background_tokens_eval_name = eval_name
        if background_tokens_eval_name:
            res = self.artm_model.score_tracker[background_tokens_eval_name].tokens
            return list(res[-1])

    @property
    def modalities_dictionary(self):
        return self.artm_model.class_ids

    @property
    def document_passes(self):
        return self.artm_model.num_document_passes

    @document_passes.setter
    def document_passes(self, iterations):
        self.artm_model.num_document_passes = iterations

    def get_reg_obj(self, reg_name):
        """Return the artm regularizer object registered under reg_name."""
        return self.artm_model.regularizers[reg_name]

    def get_reg_name(self, reg_type):
        """Return the regularizer name registered for the given long type; raises KeyError if absent."""
        return self._reg_longtype2name[reg_type]

    def get_evaluator(self, eval_name):
        """Return the ArtmEvaluator registered under the given score name."""
        return self._definition2evaluator[self._evaluator_name2definition[eval_name]]

    def set_parameter(self, reg_name, reg_param, value):
        """
        Set a single parameter on an attached regularizer.\n
        :param str reg_name: name of an attached regularizer
        :param str reg_param: attribute name on that regularizer (eg 'tau')
        :param value: the new value
        :raises RegularizerNameNotFoundException: if no regularizer named reg_name is attached
        :raises ParameterNameNotFoundException: if the regularizer has no such attribute
        """
        # Guard clauses replace the original nested if/else; messages unchanged.
        if reg_name not in self.artm_model.regularizers.data:
            raise RegularizerNameNotFoundException("Did not find a regularizer in the artm_model with name '{}'".format(reg_name))
        if not hasattr(self.artm_model.regularizers[reg_name], reg_param):
            raise ParameterNameNotFoundException("Regularizer '{}' with name '{}' has no attribute (parameter) '{}'".format(type(self.artm_model.regularizers[reg_name]).__name__, reg_name, reg_param))
        try:
            # idiomatic setattr instead of calling __setattr__ directly
            setattr(self.artm_model.regularizers[reg_name], reg_param, value)
        except (AttributeError, TypeError) as e:
            print(e)

    def set_parameters(self, reg_name2param_settings):
        """Apply a {reg_name: {param: value}} mapping via set_parameter."""
        for reg_name, settings in reg_name2param_settings.items():
            for param, value in settings.items():
                self.set_parameter(reg_name, param, value)

    def get_regs_param_dict(self):
        """
        Returns a mapping between the model's regularizers unique string definition (ie shown in train.cfg: eg 'smooth-phi', 'sparse-theta',
        'label-regulariation-phi-cls', 'decorrelate-phi-dom-def') and their corresponding parameters that can be dynamically changed during training (eg 'tau', 'gamma').\n
        See patm.modeling.regularization.regularizers.REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH\n
        :return: the regularizer type (str) to parameters (list of strings) mapping (str => list)
        :rtype: dict
        """
        d = {}
        for unique_type, reg_type in self.long_types_n_types:
            d[unique_type] = {}
            cur_reg_obj = self.artm_model.regularizers[self._reg_longtype2name[unique_type]]
            for attribute_name in REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH[reg_type]:  # ie for _ in ('tau', 'gamma')
                d[unique_type][attribute_name] = getattr(cur_reg_obj, attribute_name)
        return d
245
246
247
class TrainSpecs(object):
    """Specification of one training cycle: the number of collection passes plus a
    tau trajectory (one value per pass) for each regularizer that has one."""

    def __init__(self, collection_passes, reg_names, tau_trajectories):
        """
        :param int collection_passes: number of passes over the whole collection
        :param list reg_names: regularizer names, parallel to tau_trajectories
        :param list tau_trajectories: one trajectory (sequence of tau values) per name
        """
        self._col_iter = collection_passes
        assert len(reg_names) == len(tau_trajectories)
        self._reg_name2tau_trajectory = dict(zip(reg_names, tau_trajectories))
        # every trajectory must supply exactly one tau value per collection pass
        assert all(len(trajectory) == self._col_iter
                   for trajectory in self._reg_name2tau_trajectory.values())

    def tau_trajectory(self, reg_name):
        """Return the trajectory registered for reg_name, or None if there is none."""
        return self._reg_name2tau_trajectory.get(reg_name, None)

    @property
    def tau_trajectory_list(self):
        """Returns the list of (reg_name, tau trajectory) pairs, sorted alphabetically by regularizer name"""
        return sorted(self._reg_name2tau_trajectory.items(), key=lambda item: item[0])

    @property
    def collection_passes(self):
        """Number of passes over the collection this cycle performs."""
        return self._col_iter

    def to_taus_slice(self, iter_count):
        """Map each regularizer name to {'tau': value} for the given iteration index."""
        taus = {}
        for name, trajectory in self._reg_name2tau_trajectory.items():
            taus[name] = {'tau': trajectory[iter_count]}
        return taus
270
271
272
class RegularizerNameNotFoundException(Exception):
    """Signals that no regularizer with the requested name is attached to the artm model."""

    def __init__(self, msg):
        Exception.__init__(self, msg)
276
class ParameterNameNotFoundException(Exception):
    """Signals that a regularizer has no attribute matching the requested parameter name."""

    def __init__(self, msg):
        Exception.__init__(self, msg)
279