1
|
|
|
import abc |
2
|
|
|
from warnings import warn |
3
|
|
|
|
4
|
|
|
import artm |
5
|
|
|
|
6
|
|
|
from .trajectory import TrajectoryBuilder |
7
|
|
|
|
8
|
|
|
|
9
|
|
|
import logging |
10
|
|
|
logger = logging.getLogger(__name__) |
11
|
|
|
|
12
|
|
|
class ArtmRegularizerWrapper(object):
    """Base wrapper around an artm regularizer.

    Parses a parameters dict whose values may be scalars, numeric strings or
    trajectory definitions (eg 'linear_-2_-10', optionally prefixed with a
    start iteration as in '4_linear_-2_-10'), builds the underlying artm
    regularizer and exposes tau/alpha_iter trajectory creation.
    """
    __metaclass__ = abc.ABCMeta  # Python-2 style ABC declaration, kept as in the rest of the file
    subclasses = {}  # registry: regularizer-type string -> wrapper subclass (see register_subclass)
    labeling_ordering = ('tau', 'alpha_iter')
    # Builds the canonical trajectory definition string '<start>_<spec>' per parameter.
    # alpha_iter trajectories always start at document-pass 0, hence the fixed '0_' prefix.
    _traj_type2traj_def_creator = {'alpha_iter': lambda x: '0_' + x[1],
                                   'tau': lambda x: '{}_{}'.format(x[0], x[1])}

    def __init__(self, parameters_dict, verbose=False):
        """
        :param dict parameters_dict: regularizer settings; recognized keys include 'name',
            'long-type', 'start', 'tau', 'alpha_iter', 'topic_names', 'class_ids', 'dictionary'
        :param bool verbose: unused; kept for backwards compatibility with existing callers
        """
        self.trajectory_builder = TrajectoryBuilder()
        self._regularizer = None
        self._alpha_iter_scalar = None
        self._trajectory_lambdas = {}
        self._traj_def = {}             # parameter name -> trajectory definition string
        self._reg_constr_params = {}    # kwargs forwarded verbatim to the artm constructor
        self._params_for_labeling = {}  # parameters that participate in the 'label' property
        self._start = None              # nb of initial iterations the regularizer stays inactive
        self._name = parameters_dict.pop('name', 'no-name')
        self._long_type = parameters_dict.pop('long-type', 'type-not-found')

        if 'start' in parameters_dict:
            self._start = parameters_dict['start']
        elif isinstance(parameters_dict.get('tau'), str):  # .get: no KeyError when 'tau' is absent
            try:
                _ = float(parameters_dict['tau'])
            except ValueError:
                # 'tau' embeds the start iteration, eg '4_linear_-2_-10': split it off
                self._start = parameters_dict['tau'].split('_')[0]
                parameters_dict['tau'] = '_'.join(parameters_dict['tau'].split('_')[1:])

        for k, v in parameters_dict.items():
            # in case it is none then the parameter is handled by the default behaviour of artm
            if v is None or isinstance(v, (list, artm.dictionary.Dictionary)) or k == 'class_ids':  # one of topic_names or class_ids or dictionary
                self._reg_constr_params[k] = v
            else:
                try:  # by this point v should be a string; ValueError means a trajectory is defined (eg 'linear_-2_-10')
                    vf = float(v)
                    if k == 'alpha_iter':
                        self._alpha_iter_scalar = vf  # case: alpha_iter = a constant scalar used for each of the 'nb_document_passes' iterations
                    else:
                        self._reg_constr_params[k] = vf  # case: parameter_name == 'tau'
                    self._params_for_labeling[k] = vf
                except ValueError:
                    if self._start is None:
                        # use the module logger instead of print, consistent with the rest of the file
                        logger.info("Defaulting to activating the regularizer from the 1st iteration")
                        self._start = 0
                    # case: parameter_value is a trajectory definition without the 'start' setting
                    self._traj_def[k] = self._traj_type2traj_def_creator[k]([self._start, v])
                    self._params_for_labeling[k] = self._traj_def[k]
        self._create_artm_regularizer(dict(self._reg_constr_params, **{'name': self._name}))

    def _create_artm_regularizer(self, parameters):
        """Instantiate the wrapped artm regularizer via the subclass-defined constructor."""
        self._regularizer = self._artm_constructor(**parameters)
        logger.info("Built '{}'/'{}' reg, named '{}', with settings: {}".format(self.type, self._long_type, self._name, '{'+', '.join(map(lambda x: '{}={}'.format(x[0], x[1]), parameters.items()))+'}'))

    @classmethod
    def register_subclass(cls, regularizer_type):
        """Class decorator: register a wrapper subclass under the given regularizer-type key."""
        def decorator(subclass):
            cls.subclasses[regularizer_type] = subclass
            return subclass
        return decorator

    @classmethod
    def create(cls, regularizer_type, *args, **kwargs):
        """Factory: instantiate the wrapper registered under regularizer_type.

        :raises ValueError: if no subclass has been registered under regularizer_type
        """
        if regularizer_type not in cls.subclasses:
            raise ValueError("Bad regularizer type '{}'".format(regularizer_type))
        return cls.subclasses[regularizer_type](*args, **kwargs)

    @property
    def label(self):
        """Compact identifier: '<name>|<first-letter-of-param>:<value>|...' sorted by param name."""
        return '{}|{}'.format(self._name, '|'.join(map(lambda x: '{}:{}'.format(x[0][0], x[1]), self._get_labeling_data())))

    def _get_labeling_data(self):
        # deterministic (alphabetical) ordering so labels are reproducible across runs
        return sorted(self._params_for_labeling.items(), key=lambda x: x[0])

    def get_tau_trajectory(self, collection_passes):
        """Return the tau trajectory spanning collection_passes points, or None if tau is a scalar."""
        if 'tau' in self._traj_def:
            return self._create_trajectory('tau', collection_passes)
        return None

    def set_alpha_iters_trajectory(self, nb_document_passes):
        """Set the wrapped regularizer's alpha_iter list, one value per document pass."""
        if 'alpha_iter' in self._traj_def:
            self._regularizer.alpha_iter = list(self._create_trajectory('alpha_iter', nb_document_passes))
        elif self._alpha_iter_scalar is not None:  # 'is not None': a 0.0 scalar must be honoured too
            self._regularizer.alpha_iter = [self._alpha_iter_scalar] * nb_document_passes

    def _create_trajectory(self, name, length):
        """Build a trajectory of 'length' points from the stored '<start>_<kind>_<from>_<to>' definition."""
        fields = self._traj_def[name].split('_')
        # NOTE(review): the builder attribute is hardcoded to 'tau' even for alpha_iter
        # trajectories — presumably only the generated values matter downstream; confirm.
        return self.trajectory_builder.begin_trajectory('tau')\
            .deactivate(int(fields[0]))\
            .interpolate_to(length - int(fields[0]), float(fields[3]), interpolation=fields[1], start=float(fields[2]))\
            .create()

    @property
    def static_parameters(self):
        """Parameters that were passed verbatim to the artm regularizer constructor."""
        return self._reg_constr_params

    @property
    def artm_regularizer(self):
        """The wrapped artm regularizer instance (None until construction completes)."""
        return self._regularizer

    def __str__(self):
        return self.name

    @property
    def name(self):
        return self._name

    @property
    def type(self):
        """Short registered type string of this wrapper's exact class (None if unregistered)."""
        for k, v in ArtmRegularizerWrapper.subclasses.items():
            if type(self) is v:  # exact-class identity on purpose: each concrete subclass registers itself
                return k

    @property
    def long_type(self):
        return self._long_type
123
|
|
|
|
124
|
|
|
class SmoothSparseRegularizerWrapper(ArtmRegularizerWrapper):
    """Common base for smooth/sparse regularizer wrappers; resolves the targeted topics."""
    __metaclass__ = abc.ABCMeta

    def __init__(self, params_dict, targeted_topics):
        """
        :param dict params_dict: regularizer settings; must contain a 'name' key
        :param list of str targeted_topics: names of topics to regularize; an empty list
            (or None) makes the regularizer target all topics
        """
        if not targeted_topics:  # robust to None as well as the empty list
            logger.warning("Did not specify topics to target with the '{}' {} regularizer. By default the Smooth regularizer will target all topics. "
                           "This is recommended only if all your regularizers target all topics (no notion of background-domain separation). "
                           "If you are modeling an LDA (plsa_formula + smoothing regularization over all topics), ignore this warning.".format(params_dict['name'], type(self).__name__))
            targeted_topics = None
        super(SmoothSparseRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': targeted_topics}))
138
|
|
|
|
139
|
|
|
|
140
|
|
|
class SmoothSparsePhiRegularizerWrapper(SmoothSparseRegularizerWrapper):
    """Base wrapper constructing an artm.SmoothSparsePhiRegularizer."""
    _artm_constructor = artm.SmoothSparsePhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids):
        # merge the class_ids setting into the parameters before delegating upwards
        settings = dict(params_dict)
        settings['class_ids'] = class_ids
        super(SmoothSparsePhiRegularizerWrapper, self).__init__(settings, topic_names)
144
|
|
|
|
145
|
|
|
@ArtmRegularizerWrapper.register_subclass('sparse-phi')
class SparsePhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    """Phi regularizer wrapper registered under the 'sparse-phi' key."""

    def __init__(self, params_dict, topic_names, class_ids):
        super(SparsePhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)
149
|
|
|
|
150
|
|
|
@ArtmRegularizerWrapper.register_subclass('smooth-phi')
class SmoothPhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    """Phi regularizer wrapper registered under the 'smooth-phi' key."""

    def __init__(self, params_dict, topic_names, class_ids):
        super(SmoothPhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)
154
|
|
|
|
155
|
|
|
|
156
|
|
|
class SmoothSparseThetaRegularizerWrapper(SmoothSparseRegularizerWrapper):
    """Base wrapper constructing an artm.SmoothSparseThetaRegularizer."""
    _artm_constructor = artm.SmoothSparseThetaRegularizer

    def __init__(self, params_dict, topic_names):
        # theta regularizers take no class_ids; forward everything as-is
        super(SmoothSparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
160
|
|
|
|
161
|
|
|
|
162
|
|
|
@ArtmRegularizerWrapper.register_subclass('sparse-theta')
class SparseThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    """Theta regularizer wrapper registered under the 'sparse-theta' key."""

    def __init__(self, params_dict, topic_names):
        super(SparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
166
|
|
|
|
167
|
|
|
@ArtmRegularizerWrapper.register_subclass('smooth-theta')
class SmoothThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    """Theta regularizer wrapper registered under the 'smooth-theta' key."""

    def __init__(self, params_dict, topic_names):
        super(SmoothThetaRegularizerWrapper, self).__init__(params_dict, topic_names)
171
|
|
|
|
172
|
|
|
@ArtmRegularizerWrapper.register_subclass('label-regularization-phi')  # can be used to expand the probability space to DxWxTxC eg author-topic model
class DocumentClassificationRegularizerWrapper(ArtmRegularizerWrapper):
    """Wraps artm.LabelRegularizationPhiRegularizer; registered as 'label-regularization-phi'."""
    _artm_constructor = artm.LabelRegularizationPhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary=None, class_ids=None):
        """
        :param dict params_dict: can contain keys: 'name', 'tau', 'gamma'
        :param list of str topic_names: list of names of topics to regularize, will regularize all
            topics if empty or None. Should correspond to the domain topics
        :param dictionary: optional artm dictionary forwarded to the regularizer constructor
        :param list of str class_ids: class_ids to regularize, will regularize all classes if not specified
        """
        # TODO: the warning should fire only if a smoothing regularizer is active, because then
        # there must be non overlapping sets of 'domain' and 'background' topics
        if not topic_names:  # robust to None as well as the empty list
            # logger.warning for consistency with the other wrappers in this module (was warnings.warn)
            logger.warning("Set DocumentClassificationRegularizer to target all topics. This is valid only if you do use 'background topics'.")
            topic_names = None
        super(DocumentClassificationRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': topic_names,
                                                                                            'dictionary': dictionary,
                                                                                            'class_ids': class_ids}))
192
|
|
|
@ArtmRegularizerWrapper.register_subclass('decorrelate-phi')
class PhiDecorrelator(ArtmRegularizerWrapper):
    """Wraps artm.DecorrelatorPhiRegularizer; registered as 'decorrelate-phi'."""
    _artm_constructor = artm.DecorrelatorPhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids=None):
        settings = dict(params_dict)
        settings['topic_names'] = topic_names
        settings['class_ids'] = class_ids
        super(PhiDecorrelator, self).__init__(settings)
198
|
|
|
|
199
|
|
|
@ArtmRegularizerWrapper.register_subclass('improve-coherence')
class ImproveCoherence(ArtmRegularizerWrapper):
    """Wraps artm.ImproveCoherencePhiRegularizer; registered as 'improve-coherence'."""
    # artm signature: name=None, tau=1.0, class_ids=None, topic_names=None, dictionary=None, config=None
    _artm_constructor = artm.ImproveCoherencePhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary, class_ids=None):
        settings = dict(params_dict)
        settings['dictionary'] = dictionary
        settings['topic_names'] = topic_names
        settings['class_ids'] = class_ids
        super(ImproveCoherence, self).__init__(settings)
206
|
|
|
|