1
|
|
|
import os |
2
|
|
|
from os import path |
3
|
|
|
import re |
4
|
|
|
from glob import glob |
5
|
|
|
from collections import defaultdict |
6
|
|
|
import attr |
7
|
|
|
from math import log |
8
|
|
|
from topic_modeling_toolkit.results.experimental_results import ExperimentalResults |
9
|
|
|
import artm |
10
|
|
|
import warnings |
11
|
|
|
import pandas as pd |
12
|
|
|
|
13
|
|
|
|
14
|
|
|
import logging |
15
|
|
|
logger = logging.getLogger(__name__) |
16
|
|
|
|
17
|
|
|
############### COMPUTER ############################ |
18
|
|
|
@attr.s
class DivergenceComputer(object):  # In python 2 you MUST inherit from object to use @foo.setter feature!
    """Computes (and caches) symmetric KL divergences between the class
    distributions p(c|t) of pairs of document classes, summed over topics.

    Results are cached per Psi-matrix model, keyed by the model's ``label``
    attribute; the ``psi`` property selects the current model's cache entry.
    """
    # Cache: model label -> {'obj': PsiMatrix, 'distances': {'c1-c2': float}}.
    # attr.Factory(dict) gives each instance its own cache; the previous
    # `default={}` dict was shared by every DivergenceComputer instance.
    pct_models = attr.ib(init=False, default=attr.Factory(dict))
    __model = attr.ib(init=False)  # label of the currently selected model

    @property
    def psi(self):
        """The cache entry ({'obj': PsiMatrix, 'distances': dict}) of the current model."""
        return self.pct_models[self.__model]

    @psi.setter
    def psi(self, psi_matrix):
        """Register *psi_matrix* (keyed by its .label) and make it the current model."""
        if psi_matrix.label not in self.pct_models:
            self.pct_models[psi_matrix.label] = {'obj': psi_matrix, 'distances': {}}
        self.__model = psi_matrix.label

    def __call__(self, *args, **kwargs):
        """Return the divergences of args[0] against each of args[1:].

        :param args: class names; the first is compared against all the rest
        :param kwargs: must contain 'topics', the topic names to sum over
        :rtype: list of float
        """
        self._first_class = args[0]
        self._rest_classes = args[1:]
        return [self.symmetric_KL(self._first_class, c, kwargs['topics']) for c in self._rest_classes]

    def get_symmetric_KL(self, class1, class2, topics):
        """Cached lookup of the symmetric KL divergence between two classes.

        Both key orderings are checked because the distance is symmetric.
        The computation now runs only on a cache miss: the previous
        implementation passed symmetric_KL(...) as the default argument of
        dict.get, which is evaluated eagerly, so the divergence was
        recomputed on every call regardless of the cache.
        """
        distances = self.psi['distances']
        key = '{}-{}'.format(class1, class2)
        if key in distances:
            return distances[key]
        reversed_key = '{}-{}'.format(class2, class1)
        if reversed_key in distances:
            return distances[reversed_key]
        return self.symmetric_KL(class1, class2, topics)

    def symmetric_KL(self, class1, class2, topics):
        """Compute sum over *topics* of the pointwise symmetric KL term and cache it."""
        s = sum(self._point_sKL(class1, class2, topic) for topic in topics)
        self.psi['distances']['{}-{}'.format(class1, class2)] = s
        return s

    def _point_sKL(self, c1, c2, topic):
        """Symmetric KL contribution of a single topic; 0 when either p(c|t) is 0."""
        # hoist the two probability lookups; the original called p_ct up to 4 times
        p1 = self.p_ct(c1, topic)
        p2 = self.p_ct(c2, topic)
        if p1 == 0 or p2 == 0:
            logger.warning(
                "One of p(c|t) is zero: [{:.3f}, {:.3f}]. Skipping topic '{}' from the summation (over topics) of the symmetric KL formula, because none of limits [x->0], [y->0], [x,y->0] exist.".format(
                    p1, p2, topic))
            return 0
        return self._point_KL(c1, c2, topic) + self._point_KL(c2, c1, topic)

    def _point_KL(self, c1, c2, topic):
        """Pointwise (single-topic) KL term: p(c1|t) * log(p(c1|t) / p(c2|t))."""
        return self.p_ct(c1, topic) * log(float(self.p_ct(c1, topic) / self.p_ct(c2, topic)))

    def p_ct(self, c, t):
        """
        Probability of class=c given topic=t: p(class=c|topic=t)

        :param str c:
        :param str or int t:
        :return:
        :rtype: float
        """
        return self.psi['obj'].p_ct(c, t)
70
|
|
|
|
71
|
|
|
|
72
|
|
|
################# REPORTER #############################3 |
73
|
|
|
|
74
|
|
|
@attr.s
class PsiReporter(object):
    """Builds reports of pairwise symmetric-KL divergences between document
    classes, based on a trained model's Psi matrix of p(c|t) probabilities.

    Caches DatasetCollection objects (keyed by dataset directory path) and
    delegates the divergence computations to a DivergenceComputer.
    NOTE(review): the `dataset`, `psi_matrix` and `topics` property setters
    mutate shared state and must be invoked in that order (as pformat/values
    do) — confirm before reusing them independently.
    """
    # Cache of DatasetCollection objects keyed by dataset directory path.
    # NOTE(review): class-level mutable attr.ib default — the dict is shared
    # across all PsiReporter instances; presumably intentional (global cache).
    datasets = attr.ib(init=True, default={})
    _dataset_path = attr.ib(init=True, default='')  # currently selected dataset directory
    # Maps a topics-set name ('all'/'domain'/'background') to an extractor
    # pulling the corresponding topic names from ExperimentalResults.scalars.
    _topics_extractors = attr.ib(init=False, default={'all': lambda x: x.domain_topics + x.background_topics,
                                                      'domain': lambda x: x.domain_topics,
                                                      'background': lambda x: x.background_topics})

    # Modality names that may encode document class labels in a vocabulary file.
    discoverable_class_modality_names = attr.ib(init=True, default=['@labels_class', '@ideology_class'])
    # Divergence computer; NOTE(review): single default instance shared across reporters.
    computer = attr.ib(init=False, default=DivergenceComputer())
    # dataset_name = attr.ib(init=False, default=attr.Factory(lambda self: path.basename(self._dataset_path), takes_self=True))
    has_registered_class_names = {}
    # models = attr.ib(init=False, default=) self._selected_topics
    _precision = attr.ib(init=False, default=2)  # decimal digits requested by pformat (stored only)
    _psi = attr.ib(init=False, default='')  # currently selected PsiMatrix (set via the psi_matrix property)
    _selected_topics = attr.ib(init=False, default=[])  # topics summed over in the KL computation

    @property
    def dataset(self):
        """The currently selected DatasetCollection."""
        return self.datasets[self._dataset_path]

    @dataset.setter
    def dataset(self, dataset_path):
        """Select (and lazily construct + cache) the DatasetCollection at *dataset_path*.

        :raises RuntimeError: if the cached dataset object lacks 'class_names'
        """
        if dataset_path not in self.datasets:
            self.datasets[dataset_path] = DatasetCollection(dataset_path, self.discoverable_class_modality_names)
        self._dataset_path = dataset_path
        if not hasattr(self.datasets[dataset_path], 'class_names'):
            raise RuntimeError(
                "A dataset '{}' object found without 'class_names' attribute".format(self.datasets[dataset_path].name))
        logger.info("Dataset '{}' at {}".format(self.datasets[dataset_path].name, self.datasets[dataset_path].dir_path))
        # logger.info("{}".format(str(self.datasets[dataset_path])))
        if not self.datasets[dataset_path].doc_labeling_modality_name:
            logger.warning("Dataset's '{}' vocabulary file has no registered tokens representing document class label names".format(self.datasets[dataset_path].name))

    @property
    def psi_matrix(self):
        """The currently selected Psi matrix (classes x topics of p(c|t))."""
        return self._psi

    @psi_matrix.setter
    def psi_matrix(self, psi_matrix):
        """Set the Psi matrix after validating its shape against the dataset's
        class names (rows) and the loaded model's topic names (columns)."""
        if len(self.dataset.class_names) != psi_matrix.shape[0]:
            raise RuntimeError(
                "Number of classes do not correspond to the number of rows of Psi matrix. Found {} registered 'class names' tokens: [{}]. Psi matrix number of rows (classes) = {}.".
                format(len(self.dataset.class_names), self.dataset.class_names, psi_matrix.shape[0]))
        if len(self._topic_names) != psi_matrix.shape[1]:
            raise RuntimeError(
                "Number of topics in experimental results do not correspond to the number of columns rows of the Psi matrix. Found {} topics, while number of columns = {}".
                format(len(self._topic_names), psi_matrix.shape[1]))
        self._psi = psi_matrix

    @property
    def topics(self):
        """The selected topics to sum over when computing the symmetric KL divergence"""
        return self._selected_topics

    @topics.setter
    def topics(self, topics):
        """The selected topics to sum over when computing the symmetric KL divergence

        Accepts either a topics-set name ('all'/'domain'/'background') or an
        explicit list of topic names (which must all belong to the model).
        """
        if type(topics) == str:
            topics = self._topics_extractors[topics](self.exp_res.scalars)
        if not all(x in self._topic_names for x in topics):
            raise RuntimeError("Not all the topic names given [{}] are in the defined topics [{}] of the input model '{}'".format(', '.join(topics), ', '.join(self._topic_names), self.exp_res.scalars.model_label))
        self._selected_topics = topics

    def pformat(self, model_paths, topics_set='domain', show_model_name=True, show_class_names=True, precision=2):
        """Pretty-format one class-by-class divergence matrix per model,
        joined with blank lines. Models without a class/label modality in
        their experimental results are skipped (with a console message).

        :param list model_paths: .phi file paths or bare model labels (see paths())
        :param str topics_set: 'all', 'domain' or 'background'
        :param bool show_model_name: prepend the model label to each matrix
        :param bool show_class_names: prefix each row with its class name
        :param int precision: stored on the instance (see _str; NOTE(review):
            _str currently formats with a fixed 1-digit precision — confirm)
        :rtype: str
        """
        self._precision = precision
        b = []
        if self.dataset.doc_labeling_modality_name:
            for phi_path, json_path in self._all_paths(model_paths):
                # model_label = path.basename(json_path)
                # logger.info("Phi model '{}', experimentsl results '{}".format(phi_path, json_path))
                print("Phi model '{}', experimentsl results '{}".format(phi_path, json_path))
                model = self.artifacts(phi_path, json_path)
                # a WTDC model is one trained with at least one document-class modality
                is_WTDC_model = any(x in self.exp_res.scalars.modalities for x in self.discoverable_class_modality_names)
                if is_WTDC_model:
                    self.topics = topics_set
                    # if not self.dataset.doc_labeling_modality_name:
                    #     warnings.warn("The document class modality (one of [{}]) was found in experimental results '{}', but dataset's vocabulary file '{}' does not contain registered tokens representing the unique document classes and thus phi matrix ( p(c|t) probabilities ) where probably not computed during training.".format(', '.join(sorted(self.discoverable_class_modality_names)), path.basename(json_path), path.basename(self.dataset.vocab_file)))
                    # else:

                    self.psi_matrix = PsiMatrix.from_artm(model, self.dataset.doc_labeling_modality_name)
                    # if len(self.dataset.class_names) != self.psi.shape[0]:
                    #     raise RuntimeError("Number of classes do not correspond to the number of rows of Psi matrix. Found {} registered 'class names' tokens: [{}]. Psi matrix number of rows (classes) = {}.".
                    #                        format(len(self.dataset.class_names), self.dataset.class_names, self.psi.shape[0]))
                    #
                    # if len(self._topic_names) != self.psi.shape[1]:
                    #     raise RuntimeError(
                    #         "Number of topics in experimental results do not correspond to the number of columns rows of the Psi matrix. Found {} topics, while number of columns = {}".
                    #         format(len(self._topic_names), self.psi.shape[1]))
                    b.append(self.divergence_str(topics_set=topics_set, show_model_name=show_model_name, show_class_names=show_class_names))
                else:
                    print("Skipping model '{}' since it does not utilize any document metadata, such as document labels".format(path.basename(phi_path.replace('.phi', ''))))
        return '\n\n'.join(b)

    def values(self, model_paths, topics_set='domain'):
        """Return the raw divergence values (no string formatting).

        :param model_paths: .phi file paths or bare model labels (see paths())
        :param topics_set: 'all', 'domain' or 'background'
        :return: list of lists of lists (per model, a square class-by-class matrix)
        """
        list_of_lists = []
        if self.dataset.doc_labeling_modality_name:
            for phi_path, json_path in self._all_paths(model_paths):
                logger.info("Phi model '{}', experimentsl results '{}".format(phi_path, json_path))
                model = self.artifacts(phi_path, json_path)
                is_WTDC_model = any(
                    x in self.exp_res.scalars.modalities for x in self.discoverable_class_modality_names)
                if is_WTDC_model:
                    self.topics = topics_set
                    self.psi_matrix = PsiMatrix.from_artm(model, self.dataset.doc_labeling_modality_name)
                    self.psi_matrix.label = self.exp_res.scalars.model_label
                    self.computer.psi = self.psi_matrix
                    self.computer.class_names = self.dataset.class_names
                    list_of_lists.append([self._values(i, c) for i, c in enumerate(self.dataset.class_names)])
                else:
                    logger.info(
                        "Skipping model '{}' since it does not utilize any document metadata, such as document labels".format(
                            path.basename(phi_path.replace('.phi', ''))))
        return list_of_lists

    def artifacts(self, *args):
        """Load experimental results (args[1], a json file) and the trained
        artm model (args[0], a .phi file); caches the results on self.exp_res
        and the model's topic names on self._topic_names.

        :return: the loaded artm.ARTM model
        """
        self.exp_res = ExperimentalResults.create_from_json_file(args[1])
        self._topic_names = self.exp_res.scalars.domain_topics + self.exp_res.scalars.background_topics
        _artm = artm.ARTM(topic_names=self.exp_res.scalars.domain_topics + self.exp_res.scalars.background_topics, dictionary=self.dataset.lexicon, show_progress_bars=False)
        _artm.load(args[0])
        return _artm

    def _all_paths(self, model_paths):
        """Yield a (phi_path, results_json_path) pair for each model reference."""
        for m in model_paths:
            yield self.paths(m)

    def paths(self, *args):
        """Resolve a model reference to a (phi_path, results_json_path) pair.

        Accepts either a full path to an existing .phi file (results json is
        looked up in the sibling '../results' dir) or a bare model label
        (resolved against the dataset's 'models' and 'results' dirs).
        """
        if os.path.isfile(args[0]):  # is a full path to .phi file
            return args[0], path.join(path.dirname(args[0]), '../results', path.basename(args[0]).replace('.phi', '.json'))
        return os.path.join(self._dataset_path, 'models', args[0]), path.join(self._dataset_path, 'results', args[0]).replace('.phi', '.json')  # input is model label

    ###### STRING BUILDING
    def divergence_str(self, topics_set='domain', show_model_name=True, show_class_names=True):
        """Format the class-by-class divergence matrix of the currently loaded
        model/psi as an aligned multi-line string (see _pct_row / _str)."""
        self._show_class_names = show_class_names

        self._reportable_class_strings = list(map(lambda x: x, self.dataset.class_names))
        self.__max_class_len = max(len(x) for x in self._reportable_class_strings)
        self._psi.label = self.exp_res.scalars.model_label
        self.computer.psi = self._psi
        self.computer.class_names = self.dataset.class_names

        # format every divergence value, then pad all cells to the widest one
        string_values = [[self._str(x) for x in self._values(i, c)] for i, c in enumerate(self.dataset.class_names)]
        self.__max_len = max(max(len(x) for x in y) for y in string_values)
        _ = ''.join('{}\n'.format(self._pct_row(i, strings)) for i, strings in enumerate(string_values))
        if show_model_name:
            return "{}\n{}".format(self.exp_res.scalars.model_label, _)
        return _

    def _values(self, index, class_name):
        """Divergences of *class_name* against every class, with 0 inserted at
        its own *index* (self-distance), so the row aligns with class order."""
        distances = list(self.computer(*list([class_name] + self.dataset.class_names[:index] + self.dataset.class_names[index + 1:]), topics=self._selected_topics))
        distances.insert(index, 0)
        assert len(distances) == len(self.dataset.class_names)
        return distances

    def _pct_row(self, row_index, strings):
        """Render one padded row of the matrix, optionally prefixed by its class name."""
        if self._show_class_names:
            return '{}{} {}'.format(self._reportable_class_strings[row_index],
                                    ' ' * (self.__max_class_len - len(self._reportable_class_strings[row_index])),
                                    ' '.join('{}{}'.format(x, ' '*(self.__max_len - len(x))) for x in strings))
        return ' '.join('{}{}'.format(x, ' '*(self.__max_len - len(x))) for x in strings)

    def _str(self, value):
        """Format one divergence value; the 0 self-distance renders as empty.
        NOTE(review): uses fixed 1-digit precision, not self._precision — confirm."""
        if value == 0:
            return ''
        return '{:.1f}'.format(value)
    # def _cooc_tf(self, *args):
    #     if path.isfile(args[0]):  # is a full path to .phi file e.match(r'^ppmi_(\d+)_([td]f)\.txt$', name)
    #         c = glob('{}/ppmi_*\.txt'.format(path.join(os.path.dirname(args[0]), '../')))
    #     else:
    #         c = glob('{}/ppmi_*\.txt'.format(self._dataset_path))
    #     if not c:
    #         raise RuntimeError("Did not find any 'ppmi' files in dataset directory '{}'".format(path.dirname(args[0]), '../'))
    #     return c[0]
    #
    # def _class_names(self, allowed_modality_names):
    #     """Call this method to extract possible set of document class names out of the dataset's vocabulary file and the discoverred modality name serving to the p(c|t) \psi model.
    #     Returns None if the dataset's vocabulary does not contain registered terms as the unique document class names"""
    #     vocab_file = path.join(self._dataset_path, 'vocab.{}.txt'.format(self.dataset_name))
    #     with open(vocab_file, 'r') as f:
    #         classname_n_modality_tuples = re.findall(r'^(\w+) ({})'.format('|'.join(allowed_modality_names)), f.read(), re.M)
    #     if not classname_n_modality_tuples:
    #         return [], ''
    #     modalities = set([modality_name for _, modality_name in classname_n_modality_tuples])
    #     if len(modalities) > 1:
    #         raise ValueError("More than one candidate modalities found to serve as the document classification scheme: [{}]".format(sorted()))
    #     document_classes = [class_name for class_name, _ in classname_n_modality_tuples]
    #     if len(document_classes) > 6:
    #         warnings.warn("Detected {} classes for dataset '{}'. Perhaps too many classes for a collection of {} documents. You can define a different discretization scheme (binning of the political spectrum)".format(len(document_classes), self.dataset_name, self.nb_docs))
    #     return document_classes, modalities.pop()
    #
    # @property
    # def nb_docs(self):
    #     return self.file_len(pth.join(self._dataset_path, 'vowpal.{}.txt'.format(self.dataset_name)))
    #
    # def file_len(self, file_path):
    #     with open(file_path) as f:
    #         return len([None for i, _ in enumerate(f)])
277
|
|
|
|
278
|
|
|
##################### PSI MATRIX ############################# |
279
|
|
|
|
280
|
|
|
def _valid_probs(instance, attribute, value): |
281
|
|
|
for i, topic in enumerate(value): |
282
|
|
|
topic_specific_class_probabilities = [value[topic][x] for x in range(len(value[topic]))] |
283
|
|
|
try: |
284
|
|
|
assert abs(sum(topic_specific_class_probabilities) - 1) < 0.001 |
285
|
|
|
except AssertionError: |
286
|
|
|
raise RuntimeError("{}: [{}] sum: {} abs-diff-with-zero: {}".format(topic, ', '.join('{:.2f}'.format(x) for x in topic_specific_class_probabilities), sum(topic_specific_class_probabilities), abs(sum(topic_specific_class_probabilities) - 1))) |
287
|
|
|
|
288
|
|
|
@attr.s
class PsiMatrix(object):
    r"""Classes x Topics matrix holding p(c|t) probabilities \forall c \in C and t \in T"""
    # rows = classes, columns = topics; validated to be column-stochastic
    dataframe = attr.ib(init=True, validator=_valid_probs)
    shape = attr.ib(init=False, default=attr.Factory(lambda self: self.dataframe.shape, takes_self=True))

    def __str__(self):
        return str(self.dataframe)

    def iter_topics(self):
        """Iterate over the topic (column) names."""
        return (topic_name for topic_name in self.dataframe)

    def iterrows(self):
        """Iterate over (class_name, row Series) pairs."""
        return self.dataframe.iterrows()

    def itercolumns(self):
        """Iterate over (topic_name, column Series) pairs."""
        # DataFrame.iteritems was deprecated and removed in pandas 2.0;
        # .items() is the long-standing equivalent and works on old pandas too
        return self.dataframe.items()

    def p_ct(self, c, t):
        """
        Probability of class=c given topic=t: p(class=c|topic=t)

        :param str c:
        :param str or int t:
        :return:
        :rtype: float
        """
        return self.dataframe.loc[c][t]

    def classes_distribution(self, topic):
        """Probabilities of classes conditioned on topic; p(c|topic=topic)

        :param str topic:
        :return: the p(c|topic) probabilities as an integer-indexable object
        :rtype: pandas.core.series.Series
        """
        return self.dataframe[topic]

    @classmethod
    def from_artm(cls, artm_model, modality_name):
        """Build a PsiMatrix from a trained artm model by selecting the rows
        of its Phi matrix that belong to *modality_name*."""
        phi = artm_model.get_phi()
        psi_matrix = phi.set_index(pd.MultiIndex.from_tuples(phi.index)).loc[modality_name]
        # `cls` (not the hard-coded class name) keeps subclassing working
        return cls(psi_matrix)
329
|
|
|
|
330
|
|
|
|
331
|
|
|
########################## DATASET ########################### |
332
|
|
|
|
333
|
|
|
def _id_dir(instance, attribute, value): |
334
|
|
|
if not path.isdir(value): |
335
|
|
|
raise IOError("'{}' is not a valid directory path".format(value)) |
336
|
|
|
|
337
|
|
|
def _class_names(self, attribute, value): |
338
|
|
|
"""Call this method to extract possible set of document class names out of the dataset's vocabulary file and the discoverred modality name serving to the p(c|t) \psi model. |
339
|
|
|
Returns None if the dataset's vocabulary does not contain registered terms as the unique document class names""" |
340
|
|
|
vocab_file = path.join(self.dir_path, 'vocab.{}.txt'.format(self.name)) |
341
|
|
|
with open(vocab_file, 'r') as f: |
342
|
|
|
classname_n_modality_tuples = re.findall(r'(\w+)[\t\ ]({})'.format('|'.join(x for x in self.allowed_modality_names)), f.read()) |
343
|
|
|
|
344
|
|
|
if not classname_n_modality_tuples: |
345
|
|
|
self.class_names = [] |
346
|
|
|
self.doc_labeling_modality_name = '' |
347
|
|
|
else: |
348
|
|
|
modalities = set([modality_name for _, modality_name in classname_n_modality_tuples]) |
349
|
|
|
if len(modalities) > 1: |
350
|
|
|
raise ValueError("More than one candidate modalities found to serve as the document classification scheme: [{}]".format(sorted(x for x in modalities))) |
351
|
|
|
document_classes = [class_name for class_name, _ in classname_n_modality_tuples] |
352
|
|
|
warn_threshold = 8 |
353
|
|
|
if len(document_classes) > warn_threshold: |
354
|
|
|
warnings.warn("Detected {} classes for dataset '{}'. Perhaps too many classes for a collection of {} documents. You can define a different discretization scheme (binning of the political spectrum)".format(len(document_classes), self.name, self.nb_docs)) |
355
|
|
|
self.class_names = document_classes |
356
|
|
|
self.doc_labeling_modality_name = modalities.pop() |
357
|
|
|
|
358
|
|
|
|
359
|
|
|
def _file_len(file_path): |
360
|
|
|
with open(file_path) as f: |
361
|
|
|
return len([None for i, _ in enumerate(f)]) |
362
|
|
|
|
363
|
|
|
|
364
|
|
|
@attr.s
class DatasetCollection(object):
    """View over a dataset directory: locates the vocabulary and ppmi
    co-occurrence files, extracts document class names from the vocabulary,
    and gathers an artm Dictionary (lexicon) for the collection."""
    dir_path = attr.ib(init=True, converter=str, validator=_id_dir, repr=True)

    # attr.Factory gives each instance its own list; the previous plain-list
    # default was shared by every DatasetCollection instance
    allowed_modality_names = attr.ib(init=True, default=attr.Factory(lambda: ['@labels_class', '@ideology_class']))
    name = attr.ib(init=False, default=attr.Factory(lambda self: path.basename(self.dir_path), takes_self=True))
    vocab_file = attr.ib(init=False, default=attr.Factory(lambda self: path.join(self.dir_path, 'vocab.{}.txt'.format(self.name)), takes_self=True))
    lexicon = attr.ib(init=False, default=attr.Factory(lambda self: artm.Dictionary(name=self.name), takes_self=True))
    doc_labeling_modality_name = attr.ib(init=False, default='')
    # NOTE: the _class_names validator also SETS class_names and
    # doc_labeling_modality_name from the vocabulary file as a side effect
    class_names = attr.ib(init=False, default=attr.Factory(list), validator=_class_names)
    # nb_docs = attr.ib(init=False, default=attr.Factory(lambda self: _file_len(path.join(self.dir_path, 'vowpal.{}.txt'.format(self.name))), takes_self=True))
    ppmi_file = attr.ib(init=False, default=attr.Factory(lambda self: self._cooc_tf(), takes_self=True))

    def __attrs_post_init__(self):
        # Populate the artm Dictionary from the dataset directory once all
        # attributes (vocab/ppmi paths) have been resolved.
        self.lexicon.gather(data_path=self.dir_path,
                            cooc_file_path=self.ppmi_file,
                            vocab_file_path=self.vocab_file,
                            symmetric_cooc_values=True)

    def _cooc_tf(self):
        """Return the first 'ppmi_*tf.txt' co-occurrence file in the dataset dir.

        :raises RuntimeError: when no such file exists
        """
        c = glob('{}/ppmi_*tf.txt'.format(self.dir_path))
        if not c:
            raise RuntimeError("Did not find any 'ppmi' (computed with simple 'tf' scheme) files in dataset directory '{}'".format(self.dir_path))
        return c[0]
388
|
|
|
|