Passed
Pull Request — dev (#2)
by
unknown
02:30
created

Topic.kernel_thresholds()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import re
2
import os
3
from math import ceil
4
from functools import reduce
5
6
from operator import itemgetter
7
import warnings
8
9
from topic_modeling_toolkit.results import ExperimentalResults
10
11
12
def parse_sort(sort_def):
    """Split a sort definition such as 'coherence-80', 'purity-0.8' or 'name'.

    Returns a dict with the metric name under 'type' and a normalised
    two-decimal threshold string (eg '0.80'), or None, under 'threshold'.
    """
    name, prefix, digits = re.search(r'^([a-zA-Z]+)-?(?:(0\.)?(\d\d?))?', sort_def).groups()
    return {'type': name, 'threshold': _threshold(prefix, digits)}


def parse_tokens_type(tokens_type_def):
    """Split a tokens-type definition such as 'top-tokens', 'kernel-60' or 'kernel-0.6'.

    Returns a dict with the same shape as :func:`parse_sort`.
    """
    name, prefix, digits = re.search(r'^((?:[a-zA-Z\-]+)*(?:[a-zA-Z]+))-?(?:(0\.)?(\d\d?))?', tokens_type_def).groups()
    return {'type': name, 'threshold': _threshold(prefix, digits)}


def _threshold(prefix, digits):
    """Normalise the two optional regex captures into a '0.XY' string (or None)."""
    if prefix is None and digits is None:
        return None
    # a single captured digit (eg '8' from '0.8') is padded to two decimals
    return '0.' + digits if len(digits) == 2 else '0.' + digits + '0'
27
28
class Metric(object):
    """Base callable that extracts a sortable value from a Topic-like object."""

    def __new__(cls, *args, **kwargs):
        instance = super(Metric, cls).__new__(cls)
        # sensible defaults; subclasses overwrite these in __init__
        instance._attr_getter = None
        instance._reverse = True
        return instance

    def __call__(self, *args, **kwargs):
        topic = args[0]
        return self._attr_getter(topic)

    def sort(self, topics):
        """Return the topics ordered by this metric's extracted value."""
        return sorted(topics, key=self._attr_getter, reverse=self._reverse)


class AlphabeticalOrder(Metric):
    """Orders topics ascending by their ``name`` attribute."""

    def __init__(self):
        self._reverse = False
        self._attr_getter = lambda topic: topic.name


class KernelMetric(Metric):
    """Reads one kernel-level metric (coherence/contrast/purity) at a given threshold."""

    def __init__(self, threshold, metric_attribute, str_format='{:.1f}'):
        """
        :param float threshold: kernel threshold in (0, 1), eg 0.8
        :param str metric_attribute: attribute name to read off the kernel object
        :param str str_format: format applied to the percentage value
        """
        assert 0 < threshold < 1
        self._th = threshold
        self._kernel_metric_attribute = metric_attribute
        self._frt = str_format
        # eg threshold 0.8 -> attribute 'kernel80' -> its 'coherence'/'contrast'/'purity'
        self._attr_getter = lambda topic: getattr(
            getattr(topic, 'kernel{}'.format(self._threshold(str(self._th)[2:4]))),
            self._kernel_metric_attribute)

    def pformat(self, topic):
        """Render the topic's metric as a percentage string, eg '50.0%'."""
        return '{}%'.format(self._frt.format(self(topic) * 100))

    @staticmethod
    def _threshold(threshold):
        # pad a single digit such as '8' (from str(0.8)) to the two-digit key '80'
        return threshold + '0' if len(threshold) == 1 else threshold


class MetricsContainer(object):
    """Factory/registry for the supported topic sort metrics."""

    abbrvs = {'coh': 'coherence', 'con': 'contrast', 'pur': 'purity'}

    def __init__(self):
        self._ab = AlphabeticalOrder()

    def __iter__(self):
        return iter(('coherence', 'contrast', 'purity'))

    def __call__(self, *args, **kwargs):
        # resolve abbreviations ('coh' -> 'coherence') and dispatch to the factory method
        metric_name = self.abbrvs.get(args[0], args[0])
        return getattr(self, metric_name)(*args[1:])

    def name(self, *args):
        return self._ab

    def coherence(self, threshold):
        return KernelMetric(threshold, 'coherence')

    def contrast(self, threshold):
        return KernelMetric(threshold, 'contrast')

    def purity(self, threshold):
        return KernelMetric(threshold, 'purity')

    def __contains__(self, item):
        if item == 'name' or item in self.abbrvs:
            return True
        return item in iter(self)


# module-level singleton used throughout this module
metrics_container = MetricsContainer()
77
78
79
class TopicsHandler(object):
    """Formats the topics of trained models' experimental results as aligned text columns."""

    sep = ' |'
    headers_sep = ' ' * len(sep)
    # Each extractor receives a [topic, threshold] pair; the threshold (a string
    # such as '0.90') is only consulted for kernel tokens, where the attribute
    # name drops the leading '0.' (eg 'kernel90').
    token_extractor = {'top-tokens': lambda pair: pair[0].top_tokens,
                       'kernel': lambda pair: getattr(pair[0], 'kernel' + pair[1][2:]).tokens}
    length_extractor = {'top-tokens': lambda pair: len(pair[0].top_tokens),
                        'kernel': lambda pair: len(getattr(pair[0], 'kernel' + pair[1][2:]))}
    metrics = metrics_container
87
88
    def __init__(self, collections_root_dir, results_dir_name='results'):
        """
        :param str collections_root_dir: root directory holding the dataset collections
        :param str results_dir_name: name of the per-collection results sub-directory
        """
        self._cols = collections_root_dir
        self._res_dir_name = results_dir_name
        # state below is (re)populated while rendering a single pformat call
        self._res = None
        self._top_tokens_def = ''
        self._warn = []
        self._groups_type = ''
        self._model_label = ''
        self._threshold = ''
        self._max_token_length_per_column = []
        self._model_topics_hash = {}  # results json path -> ModelTopics cache
        self._max_tokens_per_row = []
        self._row_count = 0
        self._line_count = 0
        self._columns = 5
        self._topic_headers = []
        self._max_headers = []
104
105
    def _result_path(self, *args):
        """Resolve the experimental-results json path and remember the model label.

        If 1 argument is given it is assumed to be a full path to a results json
        file. If 2 arguments are given it is assumed that the 1st is a dataset
        label and the second a model label (eg 'plsa_1_3' not 'plsa_1_3.json').
        """
        if os.path.isfile(args[0]):
            self._model_label = os.path.basename(args[0]).replace('.json', '')
            return args[0]
        self._model_label = args[1]
        # honour the results_dir_name passed to __init__ instead of the
        # previously hard-coded 'results' literal (the parameter was otherwise unused)
        return os.path.join(self._cols, args[0], self._res_dir_name, '{}.json'.format(args[1]))
113
114
    def _model_topics(self, results_path):
        """Load (and memoise per path) the ModelTopics built from a results json file."""
        if results_path not in self._model_topics_hash:
            experimental_results = ExperimentalResults.create_from_json_file(results_path)
            self._model_topics_hash[results_path] = ModelTopics(experimental_results)
        return self._model_topics_hash[results_path]
118
119
    def pformat_background(self, model_results_path, columns=6, nb_tokens=100, show_title=False):
        """Pretty-format a model's background tokens in aligned columns.

        :param list model_results_path: if one element it should be a full path to a json
            file holding experimental results. If 2 elements, 1st is a dataset label and
            2nd a model label (ie 'clda_1_2')
        :param int columns: maximum number of columns to lay the tokens out in
        :param int nb_tokens: maximum number of background tokens to show
        :param bool show_title: whether to prepend a title line with the model label
        :return: the formatted text
        :rtype: str
        """
        exp_res = ExperimentalResults.create_from_json_file(self._result_path(*model_results_path))
        background_tokens = exp_res.final.background_tokens[:nb_tokens]
        title = exp_res.scalars.model_label + ' background tokens\n\n' if show_title else ''
        return self._pformat2(background_tokens, columns=columns, title=title)
131
132
    @classmethod
    def _pformat2(cls, elements_list, columns=8, lines=20, title=''):
        """Lay tokens out in bands of at most `columns` columns of `lines` lines each.

        :param list elements_list: the token strings to render
        :param int columns: maximum columns per band
        :param int lines: lines per band
        :param str title: optional text prepended verbatim
        :rtype: str
        """
        if not elements_list:
            # previously columns could become 0 here, crashing below with a
            # ZeroDivisionError on `columns * lines`
            return title
        columns = min(columns, int(ceil(len(elements_list) / lines)))
        max_token_p_col = [max(map(len, cls.column_slice(c, elements_list, lines, columns))) for c in range(columns)]
        b = title
        for r in range(int(ceil(len(elements_list) / (columns * lines)))):
            for l in range(lines):
                # pad every token to its column's widest entry
                b += '{}\n'.format(cls.sep.join(
                    '{}{}'.format(tok_el, (max_token_p_col[i] - len(tok_el)) * ' ')
                    for i, tok_el in enumerate(cls.line_slice(r, l, elements_list, lines, columns))))
            b += '\n'
        return b
143
144
    @staticmethod
    def column_slice(column_index, elements_list, nb_lines, nb_columns):
        """Gather the tokens that fall into one display column across all bands.

        A band is a block of `nb_lines` lines showing `nb_columns` columns;
        bands are stacked vertically until all elements are placed.
        """
        nb_bands = int(ceil(len(elements_list) / (nb_columns * nb_lines)))
        # flattening via comprehension also yields [] for an empty input, where
        # the previous reduce(...) over an empty sequence raised TypeError
        return [token
                for r in range(nb_bands)
                for token in elements_list[nb_lines * (r * nb_columns + column_index):
                                           nb_lines * (r * nb_columns + column_index + 1)]]
149
150
    @staticmethod
    def line_slice(r_index, l_index, all_bg_tokens, lines, columns):
        """Pick the tokens shown on one physical line of one band."""
        start = r_index * lines * columns + l_index
        stop = (r_index + 1) * lines * columns
        return all_bg_tokens[start:stop:lines]
153
154
    def pformat(self, model_results_path, topics_set, tokens_type, sort, nb_tokens, columns, trim=0, topic_info=True, show_title=False):
        """Pretty-format a sorted set of topics as aligned columns of tokens.

        :param list model_results_path: one element -> full path to a results json file;
            two elements -> a dataset label and a model label
        :param str topics_set: currently only 'domain' is supported (asserted below)
        :param str tokens_type: accepts 'top-tokens' or 'kernel-|kernel' plus a threshold like '0.80', '0.6', '0.1234', '75', '5', '4321'
        :param str sort: {'name' (alphabetical), 'coherence-th', 'contrast-th', 'purity-th'}; th is a two-digit pattern corresponding to kernel threshold
        :param int nb_tokens: maximum number of tokens printed per topic
        :param int columns: number of topic columns per band
        :param int trim: NOTE(review): never read in this method — confirm before removing
        :param bool topic_info: whether to print the metrics header under each topic name
        :param bool show_title: whether to prepend a title line
        :return:
        :rtype: str
        """
        self._columns = columns
        tokens_info = parse_tokens_type(tokens_type)
        sort_info = parse_sort(sort)
        assert tokens_info['type'] in ('top-tokens', 'kernel')
        if sort_info['type'] not in self.metrics:
            raise RuntimeError("Metric '{}' not supported. Use from [{}]".format(sort_info['type'], ', '.join("'{}'".format(x) for x in self.metrics)))
        # A kernel threshold must be derivable from either parameter whenever a
        # kernel is targeted (kernel tokens requested, or sorting by a kernel metric).
        if sort_info['threshold'] is None:
            if sort_info['type'] == 'name':
                if tokens_info['type'] == 'kernel' and tokens_info['threshold'] is None:
                    raise RuntimeError("Requested to show the kernel tokens per topic, but no threshold information was found in either the 'tokens_type' or the 'sort' parameters, which is required to target a specific kernel.")
            elif tokens_info['threshold'] is None:
                raise RuntimeError("Requested to sort topics according to a metric but no thresold was found in either the 'tokens_type' or the 'sort' parameters, which is required to target a specific kernel.")
            elif tokens_info['type'] == 'top-tokens':
                raise RuntimeError("Requested to sort topics according to a metric but no thresold was found in either the 'tokens_type' or the 'sort' parameters, which is required to target a specific kernel.")
        elif tokens_info['type'] == 'kernel' and tokens_info['threshold'] is not None and tokens_info['threshold'] != sort_info['threshold']:
            raise RuntimeError("The input token-type and sort-metric thresholds, {} and {} respectively, differ".format(tokens_info['threshold'], sort_info['threshold']))

        # prefer the sort threshold; fall back to the tokens-type one
        self._threshold = (lambda x: tokens_info['threshold'] if x is None else x)(sort_info['threshold'])

        model_topics = self._model_topics(self._result_path(*model_results_path))
        assert topics_set == 'domain'
        topics_set = getattr(model_topics, topics_set)  # either domain or background topics

        th = self._threshold
        if th is not None:
            th = float(self._threshold)
        callable_metric = self.metrics(sort_info['type'], th)
        topics = callable_metric.sort(list(topics_set))
        if len(topics) < columns:
            self._columns = len(topics)

        # widest token (or topic name) per display column; nb_tokens==0 means "all"
        self._max_token_length_per_column = [max([max(len(topic.name),
                                                      max(map(len,
                                                              list(self.token_extractor[tokens_info['type']]([topic, self._threshold]))[:nb_tokens or None])))
                                                  for topic in topics[i::self._columns]]) for i in range(self._columns)]
        self._topic_headers = list(self._gen_headers(topics, prefix='t', topic_info=topic_info))
        assert len(self._topic_headers) == len(topics)
        # widest header line per display column
        self._max_headers = [max(max(len(_) for _ in thd) for thd in self._topic_headers[c::self._columns]) for c in range(self._columns)]

        # a column must fit both its widest token and its widest header
        self._max_token_length_per_column = [max(x[0], x[1]) for x in zip(self._max_token_length_per_column, self._max_headers)]
        # per band: the longest token list among its topics (limits printed lines)
        self._max_tokens_per_row = [max(self.length_extractor[tokens_info['type']]([t, self._threshold])
                                        for t in topics[k*self._columns:(k+1)*self._columns]) for k in range(int(ceil(len(topics) / self._columns)))]
        return self._pformat(topics, tokens_info['type'], nb_tokens, title=(lambda x: 'model: {}, tokens:{}, sort:{}\nrespective metrics: coherence, contrast, purity'.format(self._model_label, tokens_type, sort) + '\n\n' if x else '')(show_title))
210
211
    def _pformat(self, topics, tokens_type, nb_tokens, title=''):
        """Render the already-sorted topics band by band into the final string."""
        out = title
        nb_bands = int(ceil(len(topics) / self._columns))
        for band in range(nb_bands):
            start = band * self._columns
            row_topics = topics[start:start + self._columns]
            nb_lines = min(self._max_tokens_per_row[band], nb_tokens)
            body = '\n'.join(self._line(l, tokens_type, row_topics) for l in range(nb_lines))
            out += self._build_headers(start, start + self._columns) + body + '\n'
            out += '\n'
        return out
219
220
    def _gen_headers(self, topics, prefix='t', topic_info=True):
        """Yield per topic its header lines: ['tNN'] or ['tNN', '<coh%> <con%> <pur%>'].

        The topic name is assumed to end in a 2-3 digit index (eg 'top_00') —
        TODO confirm against the names produced upstream.
        """
        # raw string fixes the invalid '\w'/'\d' escape sequences of the
        # previous non-raw pattern (DeprecationWarning, future SyntaxError)
        it = iter([prefix + re.search(r'^\w+(\d{2,3})$', t.name).group(1)] for t in topics)
        if topic_info and self._threshold is not None:  # strictly must check against None
            return iter([next(it)[0], self._topic_metrics_header(t, float(self._threshold))] for t in topics)
        return it
225
226
    def _build_headers(self, start, stop):
        """Join the header lines of topic columns [start, stop) side by side, padded."""
        padded = ([cell + ' ' * (self._max_token_length_per_column[i] - len(cell)) for cell in header]
                  for i, header in enumerate(self._topic_headers[slice(start, stop)]))
        # zip line-by-line across columns; shorter header lists truncate the merge
        merged = reduce(lambda left, right: [le + self.headers_sep + re_ for le, re_ in zip(left, right)], padded)
        return '{}\n'.format('\n'.join(merged))
230
231
    @classmethod
    def _topic_metrics_header(cls, topic, threshold):
        """Format the topic's kernel coherence, contrast and purity as percentages.

        :param float threshold: kernel threshold, eg 0.8
        """
        cells = (getattr(cls.metrics, name)(threshold).pformat(topic) for name in cls.metrics)
        return ' '.join(cells)
237
238
    def _line(self, l_index, tokens_type, topics):
        """Build one physical output line: the l_index-th token of each topic, padded."""
        cells = []
        for i, token in enumerate(self._gen_line_elements(l_index, tokens_type, topics)):
            padding = (self._max_token_length_per_column[i] - len(token)) * ' '
            cells.append('{}{}'.format(token, padding))
        return self.sep.join(cells)
241
242
    def _gen_line_elements(self, l_index, tokens_type, topics):
        """Yield the l_index-th token (or '') for each topic, in display order."""
        for topic in topics:
            yield self._token(topic, tokens_type, l_index)
244
245
    def _token(self, topic, tokens_type, l_index):
        """Return the topic's l_index-th token, or '' when the topic has fewer tokens."""
        # build the token list once instead of twice per call as before
        tokens = self._get_tokens_list(topic, tokens_type, threshold=self._threshold)
        if len(tokens) <= l_index:
            return ''
        return tokens[l_index]
249
250
    @classmethod
    def _get_tokens_list(cls, topic, tokens_type, threshold=''):
        """Materialise the topic's tokens for the requested tokens type."""
        extractor = cls.token_extractor[tokens_type]
        return list(extractor([topic, threshold]))
253
254
255
class ModelTopics(object):
    """Builds Topic/TopicsSet objects out of an ExperimentalResults instance."""

    @property
    def domain(self):
        # the TopicsSet holding the model's domain topics
        return self._domain_topic_set

    @property
    def background_tokens(self):
        return self._background_tokens

    def __init__(self, experimental_results):
        """
        :param results.experimental_results.ExperimentalResults experimental_results:
        """
        # pick the largest tracked 'top-tokens-N' score definition (eg 'top-tokens-100' -> 100)
        max_top_tokens = max([int(_.split('-')[-1]) for _ in experimental_results.final.top_defs])
        if max_top_tokens < 10:
            warnings.warn("You probably wouldn't want to track less than the 10 most probable [p(w|t)] tokens per inferred topic. "
                          "This is achieved by enabling a score tracker: add a line such as 'top-tokens-10 = tops10' or "
                          "'top-tokens-100 = t100' under the [scores] section.")
        self._top_tokens_def = 'top{}'.format(max_top_tokens)

        self._domain_topic_set = self._create_topics_set('domain', experimental_results)

        self._background_tokens = experimental_results.final.background_tokens

    def _create_topics_set(self, topic_set_name, exp_res):
        """Build one Topic per topic name, each with a Kernel entry per tracked threshold."""
        topic_names_list = getattr(exp_res.scalars, '{}_topics'.format(topic_set_name))
        # th appears to be a '0.XY' string (eg '0.60'): kernel attribute names
        # drop the leading '0.' — TODO confirm against ExperimentalResults
        return TopicsSet(topic_set_name, [
            Topic(tn,
                  [{'threshold': float(th),
                    'tokens': self._get_tokens(getattr(exp_res.final, 'kernel{}'.format(th[2:])), tn),
                    'metrics': {m: getattr(getattr(getattr(exp_res.tracked, 'kernel{}'.format(th[2:])), tn), m).last for m in metrics_container}}
                   for th in exp_res.tracked.kernel_thresholds],
                  self._get_tokens(getattr(exp_res.final, self._top_tokens_def), tn))
            for tn in topic_names_list])

    def _get_tokens(self, final_obj, topic_name):
        # final_obj is a per-score container with one attribute per topic name
        return getattr(final_obj, topic_name).tokens
294
295
296
class TopicsSet(object):
    """A named, dict-backed collection of Topic objects sharing one set of kernel thresholds."""

    def __init__(self, name, topics):
        """
        :param str name: label of the set, eg 'domain'
        :param list topics: Topic instances; all must expose identical kernel_thresholds
        """
        self._name = name
        self._topics = {topic.name: topic for topic in topics}
        reference = topics[0].kernel_thresholds
        if any(topic.kernel_thresholds != reference for topic in topics):
            raise RuntimeError("Unexpectedly topics with different kernels defined found; the thresholds differ.")
        self._thresholds = reference

    def __str__(self):
        return "{} Topics(nb_topics={})".format(self._name, len(self._topics))

    def __len__(self):
        return len(self._topics)

    def __getitem__(self, item):
        return self._topics[item]

    def __getattr__(self, item):
        # allow attribute-style access to a topic by its name
        return self._topics[item]

    def __contains__(self, item):
        return item in self._topics

    def __iter__(self):
        return iter(self._topics.values())

    @property
    def thresholds(self):
        # kernel threshold keys shared by every topic of the set
        return self._thresholds

    @property
    def name(self):
        return self._name

    @property
    def topic_names(self):
        return list(self._topics.keys())

    @property
    def topics_dict(self):
        return self._topics
333
334
335
class Topic(object):
    """A single inferred topic.

    Can be queried for its tokens with the 'top_tokens' property to get a list of
    tokens sorted on p(w|t).
    Can be queried per defined kernel via dynamic attributes:
    eg topic_object.kernel60.tokens, topic_object.kernel80.tokens
       topic_object.kernel60.coherence, topic_object.kernel25.purity
    """

    class Kernel(object):
        """The token kernel of a topic at one probability threshold, plus its metrics."""

        def __init__(self, threshold, tokens, metrics):
            """
            :param float threshold:
            :param list tokens:
            :param dict metrics:
            """
            assert 0 < threshold <= 1
            self.threshold = threshold
            self._tokens = tokens
            self.metrics = metrics

        @property
        def coherence(self):
            return self.metrics['coherence']

        @property
        def contrast(self):
            return self.metrics['contrast']

        @property
        def purity(self):
            return self.metrics['purity']

        @property
        def tokens(self):
            return self._tokens

        def __len__(self):
            return len(self._tokens)

    def __init__(self, name, kernel_defs, top_tokens):
        """
        :param str name:
        :param list kernel_defs: dicts like {'threshold': float, 'tokens': list, 'metrics': dict}
        :param list top_tokens:
        """
        self._name = name
        # keys such as '60', '80', '25' correspond to kernels with thresholds 0.6, 0.8, 0.25
        self._kernel_info = {}
        for kernel_def in kernel_defs:
            key = self._threshold_key(kernel_def['threshold'])
            self._kernel_info[key] = self.Kernel(kernel_def['threshold'], kernel_def['tokens'], kernel_def['metrics'])
        self._top_tokens = top_tokens

    def __str__(self):
        parts = ['tt: ' + str(len(self.top_tokens))]
        parts += ['{}: {}'.format(kernel.threshold, len(kernel.tokens)) for kernel in self.kernel_objects]
        return "'{}': [{}]".format(self._name, ', '.join(parts))

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def _threshold_key(threshold):
        # 0.8 -> '80', 0.25 -> '25'
        as_string = str(threshold)
        assert as_string.startswith('0.')
        digits = as_string[2:4]
        if len(digits) < 2:
            return digits + '0'
        return digits

    @property
    def name(self):
        return self._name

    @property
    def top_tokens(self):
        return self._top_tokens

    def __getattr__(self, item):
        # dynamic 'kernelNN' access; raises AttributeError for unknown thresholds
        key = item.replace('kernel', '')
        if key not in self._kernel_info:
            raise AttributeError("Requested kernel with threshold 0.{} ({}), but the available ones are [{}].".format(key, item, ', '.join(self.kernel_names)))
        return self._kernel_info[key]

    @property
    def kernel_thresholds(self):
        return sorted(self._kernel_info.keys())

    @property
    def kernel_names(self):
        return map('kernel{}'.format, self.kernel_thresholds)

    @property
    def kernel_objects(self):
        return [self._kernel_info[name.replace('kernel', '')] for name in self.kernel_names]
417