import abc
import logging
from warnings import warn

import artm

from .trajectory import TrajectoryBuilder

logger = logging.getLogger(__name__)

class ArtmRegularizerWrapper(object):
    """Abstract adapter around a BigARTM regularizer.

    Parses a raw parameters dict into:
      * static kwargs forwarded to the wrapped artm regularizer constructor,
      * scalar values (e.g. a constant 'alpha_iter' reused per document pass),
      * trajectory definitions (strings like 'linear_-2_-10', optionally
        prefixed with a 'start' count of initially-inactive iterations).

    Concrete subclasses set ``_artm_constructor`` and register themselves
    under a short type string via :meth:`register_subclass`.
    """
    __metaclass__ = abc.ABCMeta
    subclasses = {}  # registry: short type string -> wrapper subclass
    labeling_ordering = ('tau', 'alpha_iter')
    # Builders of the canonical '<start>_<interpolation>_<from>_<to>' trajectory
    # definition string; 'alpha_iter' trajectories always begin at iteration 0.
    _traj_type2traj_def_creator = {'alpha_iter': lambda x: '0_' + x[1],
                                   'tau': lambda x: '{}_{}'.format(x[0], x[1])}

    def __init__(self, parameters_dict, verbose=False):
        """
        :param dict parameters_dict: regularizer settings; recognized keys
            include 'name', 'long-type', 'start', 'tau', 'alpha_iter',
            'topic_names', 'class_ids', 'dictionary'. Values may be scalars,
            numeric strings, or trajectory definition strings.
        :param bool verbose: kept for backward compatibility; unused here.
        """
        self.trajectory_builder = TrajectoryBuilder()
        self._regularizer = None
        self._alpha_iter_scalar = None
        self._trajectory_lambdas = {}
        self._traj_def = {}             # trajectory definition per parameter name
        self._reg_constr_params = {}    # kwargs forwarded to the artm constructor
        self._params_for_labeling = {}  # values that participate in the label string
        self._start = None              # nb of initial iterations the regularizer stays inactive
        self._name = parameters_dict.pop('name', 'no-name')
        self._long_type = parameters_dict.pop('long-type', 'type-not-found')

        if 'start' in parameters_dict:
            # BUGFIX: pop 'start' so the parameter loop below does not forward
            # this wrapper-level setting to the artm constructor as an unknown kwarg.
            self._start = parameters_dict.pop('start')
        elif isinstance(parameters_dict.get('tau'), str):
            # A non-numeric 'tau' string may embed the start count as its first
            # '_'-separated token, e.g. '4_linear_-2_-10'.
            # (.get also fixes a KeyError when neither 'start' nor 'tau' is given.)
            try:
                float(parameters_dict['tau'])
            except ValueError:
                self._start = parameters_dict['tau'].split('_')[0]
                parameters_dict['tau'] = '_'.join(parameters_dict['tau'].split('_')[1:])

        for k, v in parameters_dict.items():
            # in case it is none then the parameter is handled by the default behaviour of artm
            if v is None or isinstance(v, (list, artm.dictionary.Dictionary)) or k == 'class_ids':
                # one of topic_names or class_ids or dictionary
                self._reg_constr_params[k] = v
            else:
                # by this point v should be a string; float() fails only if a
                # trajectory is defined (eg 'linear_-2_-10')
                try:
                    vf = float(v)
                    if k == 'alpha_iter':
                        # case: alpha_iter = a constant scalar which will be used
                        # for each of the 'nb_document_passes' iterations
                        self._alpha_iter_scalar = vf
                    else:
                        self._reg_constr_params[k] = vf  # case: parameter_name == 'tau'
                    self._params_for_labeling[k] = vf
                except ValueError:
                    if self._start is None:
                        # use the module logger instead of a bare print
                        logger.info("Defaulting to activating the regularizer from the 1st iteration")
                        self._start = 0
                    # case: parameter_value is a trajectory definition without the 'start'
                    # setting (nb of initial iterations that regularizer stays inactive)
                    self._traj_def[k] = self._traj_type2traj_def_creator[k]([self._start, v])
                    self._params_for_labeling[k] = self._traj_def[k]
        self._create_artm_regularizer(dict(self._reg_constr_params, **{'name': self._name}))

    def _create_artm_regularizer(self, parameters):
        """Instantiate the wrapped artm regularizer and log its settings."""
        self._regularizer = self._artm_constructor(**parameters)
        logger.info("Built '{}'/'{}' reg, named '{}', with settings: {}".format(self.type, self._long_type, self._name, '{'+', '.join(map(lambda x: '{}={}'.format(x[0], x[1]), parameters.items()))+'}'))

    @classmethod
    def register_subclass(cls, regularizer_type):
        """Class decorator that registers a subclass under a short type string."""
        def decorator(subclass):
            cls.subclasses[regularizer_type] = subclass
            return subclass
        return decorator

    @classmethod
    def create(cls, regularizer_type, *args, **kwargs):
        """Factory: build the wrapper registered under ``regularizer_type``.

        :raises ValueError: if no subclass is registered for the given type.
        """
        if regularizer_type not in cls.subclasses:
            raise ValueError("Bad regularizer type '{}'".format(regularizer_type))
        return cls.subclasses[regularizer_type](*args, **kwargs)

    @property
    def label(self):
        """Compact string encoding the name and the labeled parameter values."""
        return '{}|{}'.format(self._name, '|'.join(map(lambda x: '{}:{}'.format(x[0][0], x[1]), self._get_labeling_data())))

    def _get_labeling_data(self):
        """Return (parameter, value) pairs sorted by parameter name."""
        return sorted(self._params_for_labeling.items(), key=lambda x: x[0])

    def get_tau_trajectory(self, collection_passes):
        """Build the tau trajectory over collection passes, or None if not defined."""
        if 'tau' in self._traj_def:
            return self._create_trajectory('tau', collection_passes)
        return None

    def set_alpha_iters_trajectory(self, nb_document_passes):
        """Set the wrapped regularizer's alpha_iter from a trajectory or scalar."""
        if 'alpha_iter' in self._traj_def:
            self._regularizer.alpha_iter = list(self._create_trajectory('alpha_iter', nb_document_passes))
        elif self._alpha_iter_scalar:
            self._regularizer.alpha_iter = [self._alpha_iter_scalar] * nb_document_passes

    def _create_trajectory(self, name, length):
        """Build a trajectory of ``length`` points from ``self._traj_def[name]``.

        Definition format: '<start>_<interpolation>_<from>_<to>'.
        NOTE(review): the builder trajectory is always begun under the 'tau'
        key, even for 'alpha_iter' definitions — confirm this is intended.
        """
        _ = self._traj_def[name].split('_')
        return self.trajectory_builder.begin_trajectory('tau')\
            .deactivate(int(_[0]))\
            .interpolate_to(length - int(_[0]), float(_[3]), interpolation=_[1], start=float(_[2]))\
            .create()

    @property
    def static_parameters(self):
        """Constructor kwargs forwarded to the wrapped artm regularizer."""
        return self._reg_constr_params

    @property
    def artm_regularizer(self):
        """The wrapped artm regularizer instance (None before construction)."""
        return self._regularizer

    def __str__(self):
        return self.name

    @property
    def name(self):
        return self._name

    @property
    def type(self):
        """Short registered type string of this wrapper, or None if unregistered."""
        for k, v in ArtmRegularizerWrapper.subclasses.items():
            # exact class match: isinstance would conflate sub/superclass registrations
            if type(self) is v:
                return k

    @property
    def long_type(self):
        return self._long_type
class SmoothSparseRegularizerWrapper(ArtmRegularizerWrapper):
    """Abstract base for smooth/sparse regularizer wrappers.

    Injects the targeted topics into the settings dict before delegating to
    :class:`ArtmRegularizerWrapper`.
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, params_dict, targeted_topics):
        """
        :param dict params_dict: regularizer settings (must contain 'name')
        :param list targeted_topics: names of the topics to regularize; an
            empty list (or None) targets all topics
        """
        # BUGFIX: truthiness test also tolerates targeted_topics=None, which
        # previously crashed len(None).
        if not targeted_topics:
            logger.warning("Did not specify topics to target with the '{}' {} regularizer. By default the Smooth regularizer will target all topics. "
                           "This is recommended only if all your regularizers target all topics (no notion of background-domain separation). "
                           "If you are modeling an LDA (plsa_formula + smoothing regularization over all topics), ignore this warning.".format(params_dict['name'], type(self).__name__))
            targeted_topics = None  # artm's default: regularize all topics
        super(SmoothSparseRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': targeted_topics}))
class SmoothSparsePhiRegularizerWrapper(SmoothSparseRegularizerWrapper):
    """Smooth/sparse regularizer over the Phi (token-topic) matrix."""
    _artm_constructor = artm.SmoothSparsePhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids):
        # Fold the modalities selection into the settings, then delegate.
        settings = dict(params_dict)
        settings['class_ids'] = class_ids
        super(SmoothSparsePhiRegularizerWrapper, self).__init__(settings, topic_names)
@ArtmRegularizerWrapper.register_subclass('sparse-phi')
class SparsePhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    """Sparsing regularizer over the Phi matrix ('sparse-phi' type)."""

    def __init__(self, params_dict, topic_names, class_ids):
        # Pure pass-through: sparsing behaviour comes from the settings values.
        super(SparsePhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)
@ArtmRegularizerWrapper.register_subclass('smooth-phi')
class SmoothPhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    """Smoothing regularizer over the Phi matrix ('smooth-phi' type)."""

    def __init__(self, params_dict, topic_names, class_ids):
        # Pure pass-through: smoothing behaviour comes from the settings values.
        super(SmoothPhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)
class SmoothSparseThetaRegularizerWrapper(SmoothSparseRegularizerWrapper):
    """Smooth/sparse regularizer over the Theta (document-topic) matrix."""
    _artm_constructor = artm.SmoothSparseThetaRegularizer

    def __init__(self, params_dict, topic_names):
        # Theta regularizers take no class_ids; delegate settings as-is.
        super(SmoothSparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
@ArtmRegularizerWrapper.register_subclass('sparse-theta')
class SparseThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    """Sparsing regularizer over the Theta matrix ('sparse-theta' type)."""

    def __init__(self, params_dict, topic_names):
        # Pure pass-through: sparsing behaviour comes from the settings values.
        super(SparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
@ArtmRegularizerWrapper.register_subclass('smooth-theta')
class SmoothThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    """Smoothing regularizer over the Theta matrix ('smooth-theta' type)."""

    def __init__(self, params_dict, topic_names):
        # Pure pass-through: smoothing behaviour comes from the settings values.
        super(SmoothThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
@ArtmRegularizerWrapper.register_subclass('label-regularization-phi')  # can be used to expand the probability space to DxWxTxC eg author-topic model
class DocumentClassificationRegularizerWrapper(ArtmRegularizerWrapper):
    """Wrapper for artm's label-regularization-phi regularizer."""
    _artm_constructor = artm.LabelRegularizationPhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary=None, class_ids=None):
        """
        :param dict params_dict: Can contain keys: 'name', 'tau', 'gamma', 'dictionary'
        :param list of str topic_names: list of names of topics to regularize, will regularize all topics if not specified.
            Should correspond to the domain topics
        :param dictionary: the artm dictionary to use, if any
        :param list of str class_ids: class_ids to regularize, will regularize all classes if not specified
        """
        # BUGFIX: truthiness test also tolerates topic_names=None, which
        # previously crashed len(None).
        # TODO: the warning should fire if smooth is active because then there
        # must be defined non overlapping sets of 'domain' and 'background' topics
        if not topic_names:
            warn("Set DocumentClassificationRegularizer to target all topics. This is valid only if you do use 'background topics'.")
            topic_names = None  # artm's default: regularize all topics
        super(DocumentClassificationRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': topic_names,
                                                                                            'dictionary': dictionary,
                                                                                            'class_ids': class_ids}))
@ArtmRegularizerWrapper.register_subclass('decorrelate-phi')
class PhiDecorrelator(ArtmRegularizerWrapper):
    """Wrapper for artm's Phi decorrelation regularizer ('decorrelate-phi')."""
    _artm_constructor = artm.DecorrelatorPhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids=None):
        # Merge the targeting options into the settings, then delegate.
        settings = dict(params_dict)
        settings['topic_names'] = topic_names
        settings['class_ids'] = class_ids
        super(PhiDecorrelator, self).__init__(settings)
@ArtmRegularizerWrapper.register_subclass('improve-coherence')
class ImproveCoherence(ArtmRegularizerWrapper):
    """Wrapper for artm's coherence-improving Phi regularizer ('improve-coherence')."""
    # artm signature: name=None, tau=1.0, class_ids=None, topic_names=None, dictionary=None, config=None)
    _artm_constructor = artm.ImproveCoherencePhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary, class_ids=None):
        # Merge the targeting options and dictionary into the settings, then delegate.
        settings = dict(params_dict)
        settings['dictionary'] = dictionary
        settings['topic_names'] = topic_names
        settings['class_ids'] = class_ids
        super(ImproveCoherence, self).__init__(settings)