reporting.psi.DivergenceComputer.p_ct()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      1
eloc    2
nop     3
dl      0
loc     9
rs      10
c       0
b       0
f       0
import os
from os import path
import re
import logging
import warnings
from glob import glob
from math import log

import attr
import pandas as pd
import artm

from topic_modeling_toolkit.results.experimental_results import ExperimentalResults

logger = logging.getLogger(__name__)
############### COMPUTER ############################

@attr.s
class DivergenceComputer(object):  # In Python 2 you MUST inherit from object to use the @foo.setter feature!
    # attr.Factory(dict) instead of a bare {} default, so the cache is per-instance rather than shared.
    pct_models = attr.ib(init=False, default=attr.Factory(dict))
    __model = attr.ib(init=False, default='')  # attrs requires a default for attributes with init=False

    @property
    def psi(self):
        return self.pct_models[self.__model]

    @psi.setter
    def psi(self, psi_matrix):
        if psi_matrix.label not in self.pct_models:
            self.pct_models[psi_matrix.label] = {'obj': psi_matrix, 'distances': {}}
        self.__model = psi_matrix.label

    def __call__(self, *args, **kwargs):
        self._first_class = args[0]
        self._rest_classes = args[1:]
        return [self.symmetric_KL(self._first_class, c, kwargs['topics']) for c in self._rest_classes]

    def get_symmetric_KL(self, class1, class2, topics):
        """Return the cached distance for the (class1, class2) pair, in either order; compute and cache it on a miss."""
        key, reversed_key = '{}-{}'.format(class1, class2), '{}-{}'.format(class2, class1)
        if key in self.psi['distances']:
            return self.psi['distances'][key]
        if reversed_key in self.psi['distances']:
            return self.psi['distances'][reversed_key]
        # Note: the previous dict.get(..., self.symmetric_KL(...)) formulation evaluated the
        # fallback eagerly, recomputing the divergence even on a cache hit.
        return self.symmetric_KL(class1, class2, topics)

    def symmetric_KL(self, class1, class2, topics):
        s = 0
        for topic in topics:
            s += self._point_sKL(class1, class2, topic)
        self.psi['distances']['{}-{}'.format(class1, class2)] = s
        return s

    def _point_sKL(self, c1, c2, topic):
        if self.p_ct(c1, topic) == 0 or self.p_ct(c2, topic) == 0:
            logger.warning(
                "One of the p(c|t) values is zero: [{:.3f}, {:.3f}]. Skipping topic '{}' in the summation "
                "(over topics) of the symmetric KL formula, because none of the limits [x->0], [y->0], "
                "[x,y->0] exist.".format(self.p_ct(c1, topic), self.p_ct(c2, topic), topic))
            return 0
        return self._point_KL(c1, c2, topic) + self._point_KL(c2, c1, topic)

    def _point_KL(self, c1, c2, topic):
        return self.p_ct(c1, topic) * log(self.p_ct(c1, topic) / self.p_ct(c2, topic))

    def p_ct(self, c, t):
        """
        Probability of class=c given topic=t: p(class=c|topic=t)

        :param str c:
        :param str or int t:
        :return: the probability p(c|t)
        :rtype: float
        """
        return self.psi['obj'].p_ct(c, t)
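# For reference, symmetric_KL accumulates the symmetrised Kullback-Leibler divergence between
# the class-conditional probabilities of two classes, restricted to the selected topics T'
# (topics where either probability is zero are skipped, since no limit exists there):
#
#   sKL(c1, c2) = \sum_{t \in T'} [ p(c1|t) * log(p(c1|t) / p(c2|t)) + p(c2|t) * log(p(c2|t) / p(c1|t)) ]
#
# This is a sum of pointwise terms over topics, not a KL divergence between two normalised
# distributions over T'; it is symmetric in c1, c2 by construction.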
################# REPORTER #############################

@attr.s
class PsiReporter(object):
    datasets = attr.ib(init=True, default=attr.Factory(dict))  # mutable defaults wrapped in attr.Factory
    _dataset_path = attr.ib(init=True, default='')
    _topics_extractors = attr.ib(init=False, default=attr.Factory(lambda: {
        'all': lambda x: x.domain_topics + x.background_topics,
        'domain': lambda x: x.domain_topics,
        'background': lambda x: x.background_topics}))

    discoverable_class_modality_names = attr.ib(init=True, default=attr.Factory(lambda: ['@labels_class', '@ideology_class']))
    computer = attr.ib(init=False, default=attr.Factory(DivergenceComputer))
    has_registered_class_names = {}  # plain class attribute (not an attr.ib)
    _precision = attr.ib(init=False, default=2)
    _psi = attr.ib(init=False, default='')
    _selected_topics = attr.ib(init=False, default=attr.Factory(list))

    @property
    def dataset(self):
        return self.datasets[self._dataset_path]

    @dataset.setter
    def dataset(self, dataset_path):
        if dataset_path not in self.datasets:
            self.datasets[dataset_path] = DatasetCollection(dataset_path, self.discoverable_class_modality_names)
        self._dataset_path = dataset_path
        if not hasattr(self.datasets[dataset_path], 'class_names'):
            raise RuntimeError(
                "A dataset '{}' object found without a 'class_names' attribute".format(self.datasets[dataset_path].name))
        logger.info("Dataset '{}' at {}".format(self.datasets[dataset_path].name, self.datasets[dataset_path].dir_path))
        if not self.datasets[dataset_path].doc_labeling_modality_name:
            logger.warning("Dataset's '{}' vocabulary file has no registered tokens representing document class label names".format(self.datasets[dataset_path].name))

    @property
    def psi_matrix(self):
        return self._psi

    @psi_matrix.setter
    def psi_matrix(self, psi_matrix):
        if len(self.dataset.class_names) != psi_matrix.shape[0]:
            raise RuntimeError(
                "The number of classes does not correspond to the number of rows of the Psi matrix. Found {} registered 'class names' tokens: [{}]. Psi matrix number of rows (classes) = {}.".format(
                    len(self.dataset.class_names), self.dataset.class_names, psi_matrix.shape[0]))
        if len(self._topic_names) != psi_matrix.shape[1]:
            raise RuntimeError(
                "The number of topics in the experimental results does not correspond to the number of columns of the Psi matrix. Found {} topics, while the number of columns = {}".format(
                    len(self._topic_names), psi_matrix.shape[1]))
        self._psi = psi_matrix

    @property
    def topics(self):
        """The selected topics to sum over when computing the symmetric KL divergence"""
        return self._selected_topics

    @topics.setter
    def topics(self, topics):
        """The selected topics to sum over when computing the symmetric KL divergence"""
        if isinstance(topics, str):
            topics = self._topics_extractors[topics](self.exp_res.scalars)
        if not all(x in self._topic_names for x in topics):
            raise RuntimeError("Not all of the given topic names [{}] are among the defined topics [{}] of the input model '{}'".format(
                ', '.join(topics), ', '.join(self._topic_names), self.exp_res.scalars.model_label))
        self._selected_topics = topics

    def pformat(self, model_paths, topics_set='domain', show_model_name=True, show_class_names=True, precision=2):
        self._precision = precision
        b = []
        if self.dataset.doc_labeling_modality_name:
            for phi_path, json_path in self._all_paths(model_paths):
                logger.info("Phi model '{}', experimental results '{}'".format(phi_path, json_path))
                model = self.artifacts(phi_path, json_path)
                is_WTDC_model = any(x in self.exp_res.scalars.modalities for x in self.discoverable_class_modality_names)
                if is_WTDC_model:
                    self.topics = topics_set
                    self.psi_matrix = PsiMatrix.from_artm(model, self.dataset.doc_labeling_modality_name)
                    b.append(self.divergence_str(topics_set=topics_set, show_model_name=show_model_name, show_class_names=show_class_names))
                else:
                    logger.info("Skipping model '{}' since it does not utilize any document metadata, such as document labels".format(path.basename(phi_path.replace('.phi', ''))))
        return '\n\n'.join(b)
    def values(self, model_paths, topics_set='domain'):
        """
        :param list model_paths: full paths to .phi files, or model labels
        :param str topics_set: one of 'all', 'domain', 'background'
        :return: one symmetric-KL distance matrix (a list of row lists) per model
        :rtype: list of lists of lists
        """
        list_of_lists = []
        if self.dataset.doc_labeling_modality_name:
            for phi_path, json_path in self._all_paths(model_paths):
                logger.info("Phi model '{}', experimental results '{}'".format(phi_path, json_path))
                model = self.artifacts(phi_path, json_path)
                is_WTDC_model = any(
                    x in self.exp_res.scalars.modalities for x in self.discoverable_class_modality_names)
                if is_WTDC_model:
                    self.topics = topics_set
                    self.psi_matrix = PsiMatrix.from_artm(model, self.dataset.doc_labeling_modality_name)
                    self.psi_matrix.label = self.exp_res.scalars.model_label
                    self.computer.psi = self.psi_matrix
                    self.computer.class_names = self.dataset.class_names
                    list_of_lists.append([self._values(i, c) for i, c in enumerate(self.dataset.class_names)])
                else:
                    logger.info(
                        "Skipping model '{}' since it does not utilize any document metadata, such as document labels".format(
                            path.basename(phi_path.replace('.phi', ''))))
        return list_of_lists

    def artifacts(self, *args):
        """Load the trained ARTM model (.phi) and its experimental results (.json); return the artm.ARTM instance."""
        self.exp_res = ExperimentalResults.create_from_json_file(args[1])
        self._topic_names = self.exp_res.scalars.domain_topics + self.exp_res.scalars.background_topics
        _artm = artm.ARTM(topic_names=self._topic_names, dictionary=self.dataset.lexicon, show_progress_bars=False)
        _artm.load(args[0])
        return _artm

    def _all_paths(self, model_paths):
        for m in model_paths:
            yield self.paths(m)

    def paths(self, *args):
        if os.path.isfile(args[0]):  # input is a full path to a .phi file
            return args[0], path.join(path.dirname(args[0]), '../results', path.basename(args[0]).replace('.phi', '.json'))
        # input is a model label
        return os.path.join(self._dataset_path, 'models', args[0]), path.join(self._dataset_path, 'results', args[0]).replace('.phi', '.json')
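    # For illustration (the dataset path and model label below are hypothetical): with
    # self._dataset_path == '/data/my_dataset', paths('candidate.phi') resolves to
    #   ('/data/my_dataset/models/candidate.phi', '/data/my_dataset/results/candidate.json'),
    # while paths('/data/my_dataset/models/candidate.phi') yields the same pair, with the
    # .json path expressed through the '../results' sibling directory.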
    ###### STRING BUILDING
    def divergence_str(self, topics_set='domain', show_model_name=True, show_class_names=True):
        self._show_class_names = show_class_names

        self._reportable_class_strings = list(self.dataset.class_names)
        self.__max_class_len = max(len(x) for x in self._reportable_class_strings)
        self._psi.label = self.exp_res.scalars.model_label
        self.computer.psi = self._psi
        self.computer.class_names = self.dataset.class_names

        string_values = [[self._str(x) for x in self._values(i, c)] for i, c in enumerate(self.dataset.class_names)]
        self.__max_len = max(max(len(x) for x in y) for y in string_values)
        body = ''.join('{}\n'.format(self._pct_row(i, strings)) for i, strings in enumerate(string_values))
        if show_model_name:
            return "{}\n{}".format(self.exp_res.scalars.model_label, body)
        return body

    def _values(self, index, class_name):
        """Row of symmetric-KL distances from class_name to every class, with 0 inserted at the diagonal position."""
        args = [class_name] + self.dataset.class_names[:index] + self.dataset.class_names[index + 1:]
        distances = list(self.computer(*args, topics=self._selected_topics))
        distances.insert(index, 0)
        assert len(distances) == len(self.dataset.class_names)
        return distances

    def _pct_row(self, row_index, strings):
        if self._show_class_names:
            return '{}{} {}'.format(self._reportable_class_strings[row_index],
                                    ' ' * (self.__max_class_len - len(self._reportable_class_strings[row_index])),
                                    ' '.join('{}{}'.format(x, ' ' * (self.__max_len - len(x))) for x in strings))
        return ' '.join('{}{}'.format(x, ' ' * (self.__max_len - len(x))) for x in strings)

    def _str(self, value):
        if value == 0:  # the diagonal (self-distance) is rendered blank
            return ''
        # honor the 'precision' argument of pformat (was hardcoded to 1 decimal)
        return '{:.{}f}'.format(value, self._precision)
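    # An illustrative (made-up) pformat rendering for three hypothetical classes, with the
    # default precision of 2, the blank diagonal produced by _str, and padding by _pct_row:
    #
    #   my_model_label
    #   liberal           0.83 1.72
    #   centre       0.83      0.94
    #   conservative 1.72 0.94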
##################### PSI MATRIX #############################

def _valid_probs(instance, attribute, value):
    """attrs validator: every column (topic) of the Psi dataframe must hold a probability distribution over classes."""
    for topic in value:
        topic_specific_class_probabilities = value[topic].tolist()
        if abs(sum(topic_specific_class_probabilities) - 1) >= 0.001:
            raise RuntimeError("{}: [{}] sum: {} abs-diff-from-one: {}".format(
                topic,
                ', '.join('{:.2f}'.format(x) for x in topic_specific_class_probabilities),
                sum(topic_specific_class_probabilities),
                abs(sum(topic_specific_class_probabilities) - 1)))


@attr.s
class PsiMatrix(object):
    r"""Classes x Topics matrix holding the p(c|t) probabilities \forall c \in C and t \in T"""
    dataframe = attr.ib(init=True, validator=_valid_probs)
    shape = attr.ib(init=False, default=attr.Factory(lambda self: self.dataframe.shape, takes_self=True))

    def __str__(self):
        return str(self.dataframe)

    def iter_topics(self):
        return (topic_name for topic_name in self.dataframe)

    def iterrows(self):
        return self.dataframe.iterrows()

    def itercolumns(self):
        return self.dataframe.items()  # DataFrame.iteritems() was removed in pandas 2.x; items() is the replacement

    def p_ct(self, c, t):
        """
        Probability of class=c given topic=t: p(class=c|topic=t)

        :param str c:
        :param str or int t:
        :return: the probability p(c|t)
        :rtype: float
        """
        return self.dataframe.loc[c][t]

    def classes_distribution(self, topic):
        """Probabilities of classes conditioned on a topic; p(c|topic=topic)

        :param str topic:
        :return: the p(c|topic) probabilities as an integer-indexable object
        :rtype: pandas.core.series.Series
        """
        return self.dataframe[topic]

    @classmethod
    def from_artm(cls, artm_model, modality_name):
        phi = artm_model.get_phi()
        # Keep only the rows of the phi matrix that belong to the document-label modality;
        # those rows form the p(c|t) table.
        psi_matrix = phi.set_index(pd.MultiIndex.from_tuples(phi.index)).loc[modality_name]
        return PsiMatrix(psi_matrix)
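# A minimal sketch of the slicing that from_artm performs, on a hypothetical two-modality
# phi matrix (index tuples and values below are made up; each topic column of the extracted
# block sums to 1 over the class rows, as _valid_probs requires):
#
#   phi = pd.DataFrame({'top_00': [0.6, 0.4, 0.1], 'top_01': [0.3, 0.7, 0.2]},
#                      index=[('@ideology_class', 'liberal'),
#                             ('@ideology_class', 'conservative'),
#                             ('@default_class', 'dog')])
#   psi = phi.set_index(pd.MultiIndex.from_tuples(phi.index)).loc['@ideology_class']
#   # psi: one row per class name, one column per topic -- the p(c|t) table wrapped by PsiMatrix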
########################## DATASET ###########################

def _id_dir(instance, attribute, value):
    if not path.isdir(value):
        raise IOError("'{}' is not a valid directory path".format(value))


def _class_names(self, attribute, value):
    r"""attrs validator, doubling as an initialization hook: extracts the possible set of document class names
    from the dataset's vocabulary file, along with the discovered modality name serving the p(c|t) \psi model.
    Sets class_names to [] and doc_labeling_modality_name to '' if the vocabulary contains no registered tokens
    representing the unique document class names."""
    # attrs passes the instance as the first argument; it is named 'self' here because the body reads instance attributes
    vocab_file = path.join(self.dir_path, 'vocab.{}.txt'.format(self.name))
    with open(vocab_file, 'r') as f:
        classname_n_modality_tuples = re.findall(r'(\w+)[\t ]({})'.format('|'.join(self.allowed_modality_names)), f.read())

        if not classname_n_modality_tuples:
            self.class_names = []
            self.doc_labeling_modality_name = ''
        else:
            modalities = set(modality_name for _, modality_name in classname_n_modality_tuples)
            if len(modalities) > 1:
                raise ValueError("More than one candidate modality found to serve as the document classification scheme: [{}]".format(sorted(modalities)))
            document_classes = [class_name for class_name, _ in classname_n_modality_tuples]
            warn_threshold = 8
            if len(document_classes) > warn_threshold:
                # nb_docs is computed on the fly: this validator runs during __init__, before any
                # nb_docs attribute could have been set on the instance.
                nb_docs = _file_len(path.join(self.dir_path, 'vowpal.{}.txt'.format(self.name)))
                warnings.warn("Detected {} classes for dataset '{}'. Perhaps too many classes for a collection of {} documents. You can define a different discretization scheme (binning of the political spectrum)".format(len(document_classes), self.name, nb_docs))
            self.class_names = document_classes
            self.doc_labeling_modality_name = modalities.pop()
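# For illustration, _class_names matches vocabulary lines that register a token under one of the
# allowed label modalities; given a hypothetical vocab.<name>.txt excerpt
#
#   the @default_class
#   liberal @ideology_class
#   conservative @ideology_class
#
# the validator would set class_names = ['liberal', 'conservative'] and
# doc_labeling_modality_name = '@ideology_class'.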
def _file_len(file_path):
    with open(file_path) as f:
        return sum(1 for _ in f)


@attr.s
class DatasetCollection(object):
    dir_path = attr.ib(init=True, converter=str, validator=_id_dir, repr=True)

    allowed_modality_names = attr.ib(init=True, default=attr.Factory(lambda: ['@labels_class', '@ideology_class']))
    name = attr.ib(init=False, default=attr.Factory(lambda self: path.basename(self.dir_path), takes_self=True))
    vocab_file = attr.ib(init=False, default=attr.Factory(lambda self: path.join(self.dir_path, 'vocab.{}.txt'.format(self.name)), takes_self=True))
    lexicon = attr.ib(init=False, default=attr.Factory(lambda self: artm.Dictionary(name=self.name), takes_self=True))
    doc_labeling_modality_name = attr.ib(init=False, default='')
    class_names = attr.ib(init=False, default=attr.Factory(list), validator=_class_names)
    ppmi_file = attr.ib(init=False, default=attr.Factory(lambda self: self._cooc_tf(), takes_self=True))

    def __attrs_post_init__(self):
        self.lexicon.gather(data_path=self.dir_path,
                            cooc_file_path=self.ppmi_file,
                            vocab_file_path=self.vocab_file,
                            symmetric_cooc_values=True)

    def _cooc_tf(self):
        c = glob('{}/ppmi_*tf.txt'.format(self.dir_path))
        if not c:
            raise RuntimeError("Did not find any 'ppmi' (computed with the simple 'tf' scheme) files in dataset directory '{}'".format(self.dir_path))
        return c[0]
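# A minimal usage sketch, assuming a dataset directory laid out as this module expects
# (vocab.<name>.txt, vowpal.<name>.txt, a ppmi_*tf.txt co-occurrence file, and 'models'/'results'
# subdirectories); the dataset path and model label below are hypothetical:
if __name__ == '__main__':
    reporter = PsiReporter()
    reporter.dataset = '/data/political_spectrum'  # hypothetical dataset directory
    print(reporter.pformat(['candidate_model.phi'], topics_set='domain'))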