from os import path
import artm
from collections import OrderedDict
from configparser import ConfigParser

from topic_modeling_toolkit.patm.utils import cfg2model_settings
from topic_modeling_toolkit.patm.definitions import DEFAULT_CLASS_NAME, IDEOLOGY_CLASS_NAME  # the name of the default modality; it is unrelated to class labels or document classification

from .regularizers import ArtmRegularizerWrapper

import logging
logger = logging.getLogger(__name__)

import attr

my_dir = path.dirname(path.realpath(__file__))
REGULARIZERS_CFG = path.join(my_dir, './regularizers.cfg')


set1 = ('tau', 'gamma', 'class_ids', 'topic_names')
set2 = ('tau', 'topic_names', 'alpha_iter', 'doc_titles', 'doc_topic_coef')
decorrelation = ('tau', 'gamma', 'class_ids', 'topic_names', 'topic_pairs')
regularizer2parameters = {
    'smooth-phi': set1,
    'sparse-phi': set1,
    'smooth-theta': set2,
    'sparse-theta': set2,
    'decorrelate-phi': decorrelation,
    # 'decorrelate-phi-def': decorrelation,  # default token modality
    # 'decorrelate-phi-class': decorrelation,  # doc class modality
    # 'decorrelate-phi-domain': decorrelation,  # all modalities
    # 'decorrelate-phi-background': decorrelation,  # all modalities
    'label-regularization-phi': set1,
    # 'kl-function-info': ('function_type', 'power_value'),
    # 'specified-sparse-phi': ('tau', 'gamma', 'topic_names', 'class_id', 'num_max_elements', 'probability_threshold', 'sparse_by_column'),
    'improve-coherence': ('tau', 'gamma', 'class_ids', 'topic_names'),
    # 'smooth-ptdw': ('tau', 'topic_names', 'alpha_iter'),
    # 'topic-selection': ('tau', 'topic_names', 'alpha_iter')
    # These are the parameters worth experimenting with, both by setting them to different values and by changing them between training cycles.
}


# Parameters that can be dynamically updated during training, e.g. the tau coefficient of a sparsing regularizer should grow in absolute value (start below zero and keep decreasing).
supported_dynamic_parameters = ('tau', 'gamma')

REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH = {k: [p for p in v if p in supported_dynamic_parameters] for k, v in regularizer2parameters.items()}
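# Illustrative sketch of the derived mapping: only the dynamic parameters of each regularizer
# type are kept, order preserved, e.g.
#   REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH['smooth-phi']    -> ['tau', 'gamma']
#   REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH['smooth-theta']  -> ['tau']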


parameter_name2encoder = {
    'tau': str,  # the coefficient of regularization for this regularizer
    'start': str,  # the number of iterations to initially keep the regularizer turned off
    'gamma': float,  # the coefficient of relative regularization for this regularizer
    'alpha_iter': str,  # list of additional coefficients of regularization on each iteration over a document; should have length equal to model.num_document_passes (of an artm.ARTM model object)
    # 'class_ids': list,  # list of class_ids or single class_id to regularize; regularizes all classes if empty or None [ONLY for ImproveCoherencePhiRegularizer: dictionary should contain pairwise token co-occurrence info]
    # 'topic_names': list,  # specific topics to target for applying regularization; targets all if None
    # 'topic_pairs': dict,  # key=topic_name, value=dict: pairwise topic decorrelation coefficients; if None all values equal 1.0
    # 'function_type': str,  # the type of function, 'log' (logarithm) or 'pol' (polynomial)
    # 'power_value': float,  # the float power of the polynomial, ignored if 'function_type' = 'log'
    # 'doc_titles': list,  # of strings: titles of documents to be processed by this regularizer. The default empty value means processing all documents. The user should guarantee the existence and correctness of document titles in batches (e.g. in source files with data, like WV).
    # 'doc_topic_coef': list,  # list of floats, or list of lists of floats. Two cases: 1) a list of floats with length equal to the number of topics: an additional multiplier in the M-step formula besides alpha and tau, unique per topic but shared by all processed documents. 2) a list of lists of floats, with outer length equal to len(doc_titles) and each inner list of length equal to the number of topics: as in case 1, but with a unique list of multipliers for each document in doc_titles. Other documents are not regularized, per the description of the doc_titles parameter. Note that doc_topic_coef and topic_names are used together.
    # 'num_max_elements': int,  # number of elements to save in a row/column for the artm.SpecifiedSparsePhiRegularizer
    # 'probability_threshold': float,  # if m elements in a row/column sum to a value >= probability_threshold, m < n => only these elements are saved. Value should be in (0, 1), default=None
    # 'sparse_by_columns': bool,  # find max elements in column or row
}


class RegularizersFactory:
    """Supports construction of Smoothing, Sparsing, Label, Decorrelating and Coherence-improving regularizers. Currently, it
    resorts to supplying the internal artm.Dictionary to the regularizers' constructors whenever possible."""

    def __init__(self, dictionary):
        """
        :param artm.Dictionary dictionary: this object is passed as an argument to the regularizers' constructors whenever possible
        """
        # :param str regularizers_initialization_parameters: this file contains parameters and values to use as defaults when initializing regularizer objects
        self._dictionary = dictionary
        self._regs_data = None
        self._reg_settings = {}
        self._back_t, self._domain_t = [], []
        self._regularizer_type2constructor = \
            {'smooth-phi': lambda x: ArtmRegularizerWrapper.create('smooth-phi', x, self._back_t, [DEFAULT_CLASS_NAME]),
             'smooth-theta': lambda x: ArtmRegularizerWrapper.create('smooth-theta', x, self._back_t),
             'sparse-phi': lambda x: ArtmRegularizerWrapper.create('sparse-phi', x, self._domain_t, [DEFAULT_CLASS_NAME]),
             'sparse-theta': lambda x: ArtmRegularizerWrapper.create('sparse-theta', x, self._domain_t),
             'smooth-phi-dom-cls': lambda x: ArtmRegularizerWrapper.create('smooth-phi', x, self._domain_t, [IDEOLOGY_CLASS_NAME]),
             'smooth-phi-bac-cls': lambda x: ArtmRegularizerWrapper.create('smooth-phi', x, self._back_t, [IDEOLOGY_CLASS_NAME]),
             'smooth-phi-cls': lambda x: ArtmRegularizerWrapper.create('smooth-phi', x, self._back_t + self._domain_t, [IDEOLOGY_CLASS_NAME]),
             'label-regularization-phi-dom-cls': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=IDEOLOGY_CLASS_NAME),
             'decorrelate-phi-dom-def': lambda x: ArtmRegularizerWrapper.create('decorrelate-phi', x, self._domain_t,
                                                                                class_ids=DEFAULT_CLASS_NAME),
             'label-regularization-phi-dom-all': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=None),  # targets all classes, since no CLASS_LABELS list is given
             'label-regularization-phi-bac-all': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._back_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=None),
             'label-regularization-phi-dom-def': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=[DEFAULT_CLASS_NAME]),
             'label-regularization-phi-bac-def': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._back_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=DEFAULT_CLASS_NAME),
             'label-regularization-phi-bac-cls': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._back_t,
                                                                                         dictionary=self._dictionary,
                                                                                         class_ids=IDEOLOGY_CLASS_NAME),
             'label-regularization-phi-all': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t + self._back_t,
                                                                                     dictionary=self._dictionary,
                                                                                     class_ids=None),
             'label-regularization-phi-def': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t + self._back_t,
                                                                                     dictionary=self._dictionary,
                                                                                     class_ids=DEFAULT_CLASS_NAME),
             'label-regularization-phi-cls': lambda x: ArtmRegularizerWrapper.create('label-regularization-phi', x, self._domain_t + self._back_t,
                                                                                     dictionary=self._dictionary,
                                                                                     class_ids=IDEOLOGY_CLASS_NAME),
             'decorrelate-phi-def': lambda x: ArtmRegularizerWrapper.create('decorrelate-phi-def', x, self._domain_t, class_ids=DEFAULT_CLASS_NAME),
             'decorrelate-phi-class': lambda x: ArtmRegularizerWrapper.create('decorrelate-phi-class', x, self._domain_t, class_ids=IDEOLOGY_CLASS_NAME),
             'decorrelate-phi-background': lambda x: ArtmRegularizerWrapper.create('decorrelate-phi', x, self._back_t, class_ids=None),
             'improve-coherence': lambda x: ArtmRegularizerWrapper.create('improve-coherence', x, self._domain_t, self._dictionary, class_ids=DEFAULT_CLASS_NAME)}
        logger.info("Initialized RegularizersFactory with artm.Dictionary '{}'".format(self._dictionary.name))

    @property
    def regs_data(self):
        return self._regs_data

    @regs_data.setter
    def regs_data(self, regs_data):
        """
        :param RegularizersData regs_data:
        """
        self._regs_data = regs_data
        self._back_t, self._domain_t = regs_data.background_topics, regs_data.domain_topics
        self._reg_settings = regs_data.regularizers_parameters

    def create_reg_wrappers(self, reg_type2name, background_topics, domain_topics, reg_cfg=None):
        """
        Creates the wrappers for the requested (active) regularizers; each regularizer type must be a key of the
        '_regularizer_type2constructor' hash.\n
        :param str or list or dict reg_type2name: indicates which regularizers should be active; e.g. the keys of the 'regularizers' section of the train.cfg
            - If type(reg_type2name) == str: reg_type2name is a file path to a cfg-formatted file with a 'regularizers' section indicating the active regularization components.\n
            - If type(reg_type2name) == list: reg_type2name is a list whose elements are either regularizer-type strings (e.g. 'smooth-phi'; the name is then inferred as an abbreviation) or (type, name) tuples.\n
            - If type(reg_type2name) == dict: reg_type2name maps regularizer types to regularizer names.
        :param list background_topics: a list of the 'background' topic names. Can be empty.
        :param list domain_topics: a list of the 'domain' topic names. Can be empty.
        :param str or dict reg_cfg: contains the values to initialize the regularizers' parameters (e.g. tau) with. If None, the default file is used.
            - If type(reg_cfg) == str: reg_cfg is a file path to a cfg-formatted file whose sections are regularizer types and whose keys are initialization parameters.\n
            - If type(reg_cfg) == dict: reg_cfg maps regularizer types to parameter dicts.
        :return: the constructed regularizer wrappers
        :rtype: list of ArtmRegularizerWrapper
        """
        if not reg_cfg:
            reg_cfg = REGULARIZERS_CFG
        self._regs_data = RegularizersData(background_topics, domain_topics, reg_type2name, reg_cfg)
        self._back_t, self._domain_t = background_topics, domain_topics
        self._reg_settings = self._regs_data.regularizers_parameters
        logger.info("Active regs: {}".format(self._regs_data.regs_hash))
        logger.info("Regs default inits cfg file: {}".format(reg_cfg))
        logger.info("Reg settings: {}".format(self._regs_data.regularizers_parameters))
        return [self.construct_reg_wrapper(reg, params) for reg, params in sorted(self._regs_data.regularizers_parameters.items(), key=lambda x: x[0])]

    # def _create_reg_wrappers(self):
    #     """
    #     Call this method to create all possible regularization components for the model.\n
    #     Each key is a regularizer type; each value is a dictionary of the corresponding regularizer's initialization parameters; i.e.:
    #     {\n
    #      'sparse-phi': {'name': 'spp', 'start': 5, 'tau': 'linear_-0.5_-5'},\n
    #      'smooth-phi': {'name': 'smp', 'tau': 1},\n
    #      'sparse-theta': {'name': 'spt', 'start': 3, 'tau': 'linear_-0.2_-6', 'alpha_iter': 1},\n
    #      'smooth-theta': {'name': 'smt', 'tau': 1, 'alpha_iter': 'linear_0.5_1'}\n
    #     }
    #     :return: the constructed regularizers; objects of type ArtmRegularizerWrapper
    #     :rtype: list
    #     """
    #     return list(filter(None, map(lambda x: self.construct_reg_wrapper(x[0], x[1]), sorted(self._reg_settings.items(), key=lambda y: y[0]))))

    def construct_reg_wrapper(self, reg_type, settings):
        """
        :param str reg_type: the regularizer's unique definition, based on the reg type, the topics targeted and the modality targeted
        :param dict settings: key-value pairs to initialize the regularizer parameters with. Must contain a 'name' key
        :return: a reference to the regularizer's wrapper object
        :rtype: ArtmRegularizerWrapper
        """
        if reg_type not in self._regularizer_type2constructor:
            raise RuntimeError("Requested to create '{}' regularizer, which is not supported".format(reg_type))
        if (self._back_t is None or len(self._back_t) == 0) and reg_type.startswith('smooth'):
            logger.warning("Requested to create '{}' regularizer, which normally targets 'background' topics, but there are "
                           "no distinct 'background' topics defined. The constructed regularizer will target all topics instead.".format(reg_type))
        # manually insert the 'long-type' string into the settings hash to use it as the truly unique 'type' of a regularizer
        return self._regularizer_type2constructor[reg_type](dict(settings, **{'long-type': reg_type}))
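
# Minimal usage sketch (illustrative only; the dictionary, topic names and parameter values
# below are hypothetical placeholders, not part of the toolkit's shipped configuration):
#
#   dictionary = ...  # an artm.Dictionary built/loaded for the collection (see the BigARTM docs)
#   factory = RegularizersFactory(dictionary)
#   wrappers = factory.create_reg_wrappers(
#       {'smooth-phi': 'smp', 'sparse-theta': 'spt'},                 # active regularizers: type -> name
#       background_topics=['background_0', 'background_1'],           # hypothetical topic names
#       domain_topics=['topic_{}'.format(i) for i in range(10)],
#       reg_cfg={'smooth-phi': {'tau': 1.0},
#                'sparse-theta': {'tau': 'linear_-0.2_-6', 'start': 3}})
#   # 'wrappers' is a list of ArtmRegularizerWrapper objects, one per active regularizer.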


def _parse_active_regs(regs):
    reg_settings = {'OrderedDict': lambda x: x.items(),
                    'dict': lambda x: x.items(),
                    'str': lambda x: cfg2model_settings(x)['regularizers'],
                    'list': lambda x: dict(_abbreviation(element) for element in x)}
    return dict(reg_settings[type(regs).__name__](regs))  # (reg-type, reg-name) pairs as a dict
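
# For example (a sketch; the 'str' branch additionally depends on cfg2model_settings and an actual cfg file):
#   _parse_active_regs(['smooth-phi', ('sparse-theta', 'spt')])  -> {'smooth-phi': 'smph', 'sparse-theta': 'spt'}
#   _parse_active_regs({'smooth-theta': 'smt'})                  -> {'smooth-theta': 'smt'}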

def _abbreviation(reg_type):
    if type(reg_type) == str:
        r = reg_type.split('-')
        if len(r) == 1:
            return reg_type, reg_type
        if len(r) == 2:
            return reg_type, ''.join(x[:2] for x in r)
        return reg_type, ''.join(x[0] for x in r)
    elif type(reg_type) == tuple:
        return reg_type
    else:
        raise ValueError("Either input a string representing a regularizer type (implying that the name should be inferred as an abbreviation) or a tuple holding both the type and the name")
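
# Abbreviation behavior, for reference:
#   _abbreviation('smooth-phi')                   -> ('smooth-phi', 'smph')
#   _abbreviation('label-regularization-phi')     -> ('label-regularization-phi', 'lrp')
#   _abbreviation(('sparse-theta', 'my-sparser')) -> ('sparse-theta', 'my-sparser')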


def cfg2regularizer_settings(cfg_file):
    config = ConfigParser()
    config.read(u'{}'.format(cfg_file))
    return OrderedDict([(str(section),
                         OrderedDict([(str(setting_name),
                                       parameter_name2encoder[str(setting_name)](value)) for setting_name, value in config.items(section) if value])
                         ) for section in config.sections()])
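
# Sketch of the expected 'regularizers.cfg' layout (hypothetical values); each section is a
# regularizer type and each key must appear in 'parameter_name2encoder':
#
#   [smooth-phi]
#   tau = 1.0
#
#   [sparse-theta]
#   tau = linear_-0.2_-6
#   start = 3
#
# cfg2regularizer_settings('regularizers.cfg') would then return something like
#   OrderedDict([('smooth-phi', OrderedDict([('tau', '1.0')])),
#                ('sparse-theta', OrderedDict([('tau', 'linear_-0.2_-6'), ('start', '3')]))])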


def _parse_reg_cfg(regs_config):
    reg_initialization_type2_dict = {'dict': lambda x: x,
                                     'OrderedDict': lambda x: x,
                                     'str': lambda x: cfg2regularizer_settings(x)}
    return reg_initialization_type2_dict[type(regs_config).__name__](regs_config)


def _create_reg_settings(self):
    regs_init_params = {}
    _settings = _parse_reg_cfg(self.reg_cfg)
    for reg_unique_type, reg_name in self.regs_hash.items():
        try:
            regs_init_params[reg_unique_type] = dict(_settings[reg_unique_type], **{'name': reg_name})
        except KeyError:
            raise KeyError("'reg_cfg' {}, resulting in settings {}, does not have key '{}'. Probably you forgot to add the corresponding entry in 'train.cfg' and/or 'regularizers.cfg'".format(self.reg_cfg, _settings, reg_unique_type))
            # "Keys in regs_cfg: [{}], keys requested as active regularizers: [{}], current key: {}, name: {}. "
            # "Probably you forgot to add the corresponding entry in 'train.cfg' and/or 'regularizers.cfg'".format(
            #     ', '.join(sorted(self.reg_cfg.keys())),
            #     ', '.join(sorted(self.reg_cfg.keys())),
            #     reg_unique_type, reg_name))
    return regs_init_params


@attr.s(cmp=True, hash=True, slots=False)
class RegularizersData(object):
    background_topics = attr.ib(init=True, converter=list, repr=True, cmp=True)
    domain_topics = attr.ib(init=True, converter=list, repr=True, cmp=True)
    regs_hash = attr.ib(init=True, converter=_parse_active_regs)
    reg_cfg = attr.ib(init=True)
    regularizers_parameters = attr.ib(default=attr.Factory(lambda self: _create_reg_settings(self), takes_self=True), repr=True, cmp=True, hash=True, init=False)
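
# A brief illustration with hypothetical inputs: the attrs converters parse the active-regularizers
# spec and the reg_cfg, and 'regularizers_parameters' is then derived automatically:
#
#   rd = RegularizersData(['background_0'], ['topic_0', 'topic_1'],
#                         ['smooth-phi'],                        # regs_hash becomes {'smooth-phi': 'smph'}
#                         {'smooth-phi': {'tau': 1.0}})          # reg_cfg given as a dict instead of a file path
#   rd.regularizers_parameters  -> {'smooth-phi': {'tau': 1.0, 'name': 'smph'}}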