ArtmRegularizerWrapper.type()    grade: A

Complexity
    Conditions: 3

Size
    Total lines: 5
    Code lines: 5

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        3
eloc      5
nop       1
dl        0
loc       5
rs        10
c         0
b         0
f         0
import abc
import logging
from warnings import warn

import artm

from .trajectory import TrajectoryBuilder

logger = logging.getLogger(__name__)


class ArtmRegularizerWrapper(object):
    __metaclass__ = abc.ABCMeta
    subclasses = {}
    labeling_ordering = ('tau', 'alpha_iter')
    _traj_type2traj_def_creator = {'alpha_iter': lambda x: '0_' + x[1],
                                   'tau': lambda x: '{}_{}'.format(x[0], x[1])}

    def __init__(self, parameters_dict, verbose=False):
        self.trajectory_builder = TrajectoryBuilder()
        self._regularizer = None
        self._alpha_iter_scalar = None
        self._trajectory_lambdas = {}
        self._traj_def = {}
        self._reg_constr_params = {}
        self._params_for_labeling = {}
        self._start = None
        self._name = parameters_dict.pop('name', 'no-name')
        self._long_type = parameters_dict.pop('long-type', 'type-not-found')

        if 'start' in parameters_dict:
            self._start = parameters_dict['start']
        elif isinstance(parameters_dict.get('tau'), str):
            try:
                _ = float(parameters_dict['tau'])
            except ValueError:
                # 'tau' is a trajectory definition whose first token is the 'start' setting
                self._start = parameters_dict['tau'].split('_')[0]
                parameters_dict['tau'] = '_'.join(parameters_dict['tau'].split('_')[1:])

        for k, v in parameters_dict.items():
            # if the value is None, the parameter is handled by artm's default behaviour
            if v is None or type(v) == list or type(v) == artm.dictionary.Dictionary or k == 'class_ids':  # one of topic_names, class_ids or dictionary
                self._reg_constr_params[k] = v
            else:
                try:  # by this point v should be a string; an exception should occur only if a trajectory is defined (eg 'linear_-2_-10')
                    vf = float(v)
                    if k == 'alpha_iter':
                        self._alpha_iter_scalar = vf  # case: alpha_iter is a constant scalar used for each of the 'nb_document_passes' iterations
                    else:
                        self._reg_constr_params[k] = vf  # case: parameter_name == 'tau'
                    self._params_for_labeling[k] = vf
                except ValueError:
                    if self._start is None:
                        logger.info("Defaulting to activating the regularizer from the 1st iteration")
                        self._start = 0
                    # case: the value is a trajectory definition without the 'start' setting
                    # (the nb of initial iterations the regularizer stays inactive)
                    self._traj_def[k] = self._traj_type2traj_def_creator[k]([self._start, v])
                    self._params_for_labeling[k] = self._traj_def[k]
        self._create_artm_regularizer(dict(self._reg_constr_params, **{'name': self._name}))
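
    # Sketch of the value shapes accepted in 'parameters_dict' (inferred from the
    # branches of __init__ above; the values below are illustrative only):
    #   None, a list, an artm Dictionary, or the 'class_ids' key
    #       -> forwarded unchanged to the artm regularizer constructor
    #   a numeric string, eg 'tau': '-0.1'
    #       -> used as a constant scalar (repeated per document pass for 'alpha_iter')
    #   a trajectory definition, eg 'tau': '5_linear_-1_-10'
    #       -> converted to a tau/alpha_iter trajectory (see _create_trajectory)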

    def _create_artm_regularizer(self, parameters):
        self._regularizer = self._artm_constructor(**parameters)
        logger.info("Built '{}'/'{}' reg, named '{}', with settings: {}".format(
            self.type, self._long_type, self._name,
            '{' + ', '.join('{}={}'.format(k, v) for k, v in parameters.items()) + '}'))

    @classmethod
    def register_subclass(cls, regularizer_type):
        def decorator(subclass):
            cls.subclasses[regularizer_type] = subclass
            return subclass
        return decorator

    @classmethod
    def create(cls, regularizer_type, *args, **kwargs):
        if regularizer_type not in cls.subclasses:
            raise ValueError("Bad regularizer type '{}'".format(regularizer_type))
        return cls.subclasses[regularizer_type](*args, **kwargs)
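
    # Example (sketch): 'create' dispatches on the type strings registered via
    # 'register_subclass' below; the remaining arguments are forwarded to the
    # matching subclass constructor. The parameter values here are made up:
    #   wrapper = ArtmRegularizerWrapper.create(
    #       'smooth-phi', {'name': 'sm_phi', 'tau': '1.0'},
    #       ['background_topic'], None)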

    @property
    def label(self):
        return '{}|{}'.format(self._name, '|'.join('{}:{}'.format(k[0], v) for k, v in self._get_labeling_data()))

    def _get_labeling_data(self):
        return sorted(self._params_for_labeling.items(), key=lambda x: x[0])

    def get_tau_trajectory(self, collection_passes):
        if 'tau' in self._traj_def:
            return self._create_trajectory('tau', collection_passes)
        return None

    def set_alpha_iters_trajectory(self, nb_document_passes):
        if 'alpha_iter' in self._traj_def:
            self._regularizer.alpha_iter = list(self._create_trajectory('alpha_iter', nb_document_passes))
        elif self._alpha_iter_scalar:
            self._regularizer.alpha_iter = [self._alpha_iter_scalar] * nb_document_passes

    def _create_trajectory(self, name, length):
        start, kind, frm, to = self._traj_def[name].split('_')
        return self.trajectory_builder.begin_trajectory(name)\
            .deactivate(int(start))\
            .interpolate_to(length - int(start), float(to), interpolation=kind, start=float(frm))\
            .create()
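
    # A trajectory definition is an underscore-separated string of the form
    # '<start>_<kind>_<from>_<to>', eg '2_linear_-1_-10': stay inactive for the
    # first 2 iterations, then interpolate from -1 to -10 over the remaining
    # ones (sketch; the exact interpolation semantics live in TrajectoryBuilder).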

    @property
    def static_parameters(self):
        return self._reg_constr_params

    @property
    def artm_regularizer(self):
        return self._regularizer

    def __str__(self):
        return self.name

    @property
    def name(self):
        return self._name

    @property
    def type(self):
        for k, v in ArtmRegularizerWrapper.subclasses.items():
            if type(self) == v:
                return k

    @property
    def long_type(self):
        return self._long_type


class SmoothSparseRegularizerWrapper(ArtmRegularizerWrapper):
    __metaclass__ = abc.ABCMeta

    def __init__(self, params_dict, targeted_topics):
        """
        :param dict params_dict: the regularizer's parameters, eg 'name', 'tau', 'alpha_iter'
        :param list targeted_topics: names of the topics to target; an empty list targets all topics
        """
        if len(targeted_topics) == 0:
            logger.warning("Did not specify topics to target with the '{}' {} regularizer. By default the Smooth regularizer will target all topics. "
                           "This is recommended only if all your regularizers target all topics (no notion of background-domain separation). "
                           "If you are modeling an LDA (plsa_formula + smoothing regularization over all topics), ignore this warning.".format(params_dict['name'], type(self).__name__))
            targeted_topics = None
        super(SmoothSparseRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': targeted_topics}))


class SmoothSparsePhiRegularizerWrapper(SmoothSparseRegularizerWrapper):
    _artm_constructor = artm.SmoothSparsePhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids):
        super(SmoothSparsePhiRegularizerWrapper, self).__init__(dict(params_dict, **{'class_ids': class_ids}), topic_names)


@ArtmRegularizerWrapper.register_subclass('sparse-phi')
class SparsePhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    def __init__(self, params_dict, topic_names, class_ids):
        super(SparsePhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)


@ArtmRegularizerWrapper.register_subclass('smooth-phi')
class SmoothPhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper):
    def __init__(self, params_dict, topic_names, class_ids):
        super(SmoothPhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids)


class SmoothSparseThetaRegularizerWrapper(SmoothSparseRegularizerWrapper):
    _artm_constructor = artm.SmoothSparseThetaRegularizer

    def __init__(self, params_dict, topic_names):
        super(SmoothSparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)


@ArtmRegularizerWrapper.register_subclass('sparse-theta')
class SparseThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    def __init__(self, params_dict, topic_names):
        super(SparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names)


@ArtmRegularizerWrapper.register_subclass('smooth-theta')
class SmoothThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper):
    def __init__(self, params_dict, topic_names):
        super(SmoothThetaRegularizerWrapper, self).__init__(params_dict, topic_names)


# can be used to expand the probability space to DxWxTxC, eg an author-topic model
@ArtmRegularizerWrapper.register_subclass('label-regularization-phi')
class DocumentClassificationRegularizerWrapper(ArtmRegularizerWrapper):
    _artm_constructor = artm.LabelRegularizationPhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary=None, class_ids=None):
        """
        :param dict params_dict: can contain the keys 'name', 'tau', 'gamma'
        :param list of str topic_names: names of the topics to regularize; regularizes all topics if not specified.
            Should correspond to the domain topics
        :param dictionary: optional artm Dictionary passed to the underlying artm regularizer
        :param list of str class_ids: class_ids to regularize; regularizes all classes if not specified
        """
        # TODO: the warning below should fire only if smoothing is active, because then
        # non-overlapping sets of 'domain' and 'background' topics must be defined
        if len(topic_names) == 0:
            warn("Set DocumentClassificationRegularizer to target all topics. This is valid only if you do not use 'background topics'.")
            topic_names = None
        super(DocumentClassificationRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': topic_names,
                                                                                            'dictionary': dictionary,
                                                                                            'class_ids': class_ids}))
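

# Example (sketch, with made-up names): document labels enter as an extra
# modality, so a typical call passes the label class id and a dictionary
# covering it:
#   dc_reg = ArtmRegularizerWrapper.create(
#       'label-regularization-phi', {'name': 'label_reg', 'tau': '1.0'},
#       ['topic_00', 'topic_01'], dictionary=my_dictionary, class_ids=['@labels'])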


@ArtmRegularizerWrapper.register_subclass('decorrelate-phi')
class PhiDecorrelator(ArtmRegularizerWrapper):
    _artm_constructor = artm.DecorrelatorPhiRegularizer

    def __init__(self, params_dict, topic_names, class_ids=None):
        super(PhiDecorrelator, self).__init__(dict(params_dict, **{'topic_names': topic_names,
                                                                   'class_ids': class_ids}))


@ArtmRegularizerWrapper.register_subclass('improve-coherence')
class ImproveCoherence(ArtmRegularizerWrapper):
    # artm signature: name=None, tau=1.0, class_ids=None, topic_names=None, dictionary=None, config=None
    _artm_constructor = artm.ImproveCoherencePhiRegularizer

    def __init__(self, params_dict, topic_names, dictionary, class_ids=None):
        super(ImproveCoherence, self).__init__(dict(params_dict, **{'dictionary': dictionary,
                                                                    'topic_names': topic_names,
                                                                    'class_ids': class_ids}))
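

# Illustrative end-to-end sketch (not part of the analyzed module). It assumes a
# working BigARTM installation; the regularizer name, tau trajectory and topic
# names below are made up for demonstration.
if __name__ == '__main__':
    sparser = ArtmRegularizerWrapper.create(
        'sparse-theta',
        {'name': 'sparse_theta_reg', 'tau': '5_linear_-1_-10'},  # inactive for 5 passes, then -1 -> -10 linearly
        ['topic_00', 'topic_01'])
    print(sparser.label)                   # eg 'sparse_theta_reg|t:5_linear_-1_-10'
    print(sparser.get_tau_trajectory(20))  # tau trajectory spanning 20 collection passes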