Completed
Push — master (2006d6...a6ef36) by Osma
13s queued, 11s created

annif.backend.vw_multi (rating: B)

Complexity

Total Complexity 46

Size/Duplication

Total Lines 238
Duplicated Lines 0%

Importance

Changes 0
Metric  Value
wmc     46
eloc    194
dl      0
loc     238
rs      8.72
c       0
b       0
f       0

20 Methods

Rating  Name                                      Duplication  Size  Complexity
A       VWMultiBackend._create_train_file()       0            7     1
A       VWMultiBackend._suggest_chunks()          0            16    4
A       VWMultiBackend._create_params()           0            9     1
A       VWMultiBackend._inputs_to_exampletext()   0            10    4
A       VWMultiBackend._normalize_text()          0            3     1
A       VWMultiBackend.inputs()                   0            4     1
A       VWMultiBackend._create_model()            0            13    2
A       VWMultiBackend.initialize()               0            15    4
A       VWMultiBackend._create_examples()         0            9     3
A       VWMultiBackend._convert_result()          0            14    3
A       VWMultiBackend._get_input()               0            10    2
A       VWMultiBackend._cleanup_text()            0            4     1
A       VWMultiBackend.learn()                    0            6     2
A       VWMultiBackend._write_train_file()        0            5     3
A       VWMultiBackend.train()                    0            4     1
A       VWMultiBackend._convert_param()           0            14    4
A       VWMultiBackend._uris_to_subject_ids()     0            7     3
A       VWMultiBackend.algorithm()                0            9     2
A       VWMultiBackend._format_examples()         0            7     3
A       VWMultiBackend.default_params()           0            8     1

How to fix: Complexity

Complex classes like annif.backend.vw_multi often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields and methods that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often the faster option; a sketch of one possible extraction follows.
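
In this module, the Vowpal Wabbit parameter handling (VW_PARAMS, _convert_param() and _create_params(), plus the default-merging in default_params()) is one such cohesive component. Below is a minimal sketch of what extracting it into a separate helper could look like; the class name VWParameterSpec and its method names are made up for illustration and are not part of Annif:

from annif.exception import ConfigurationException


class VWParameterSpec:
    """Hypothetical helper that validates and converts Vowpal Wabbit
    parameters according to a spec of {name: (converter, default)} pairs."""

    def __init__(self, spec, backend_id):
        self._spec = spec
        self._backend_id = backend_id

    def defaults(self):
        # parameters that declare a non-None default value
        return {param: default
                for param, (_, default) in self._spec.items()
                if default is not None}

    def convert(self, param, val):
        converter, _ = self._spec[param]
        if isinstance(converter, list):  # enumerated choices
            if val in converter:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(converter)),
                backend_id=self._backend_id)
        try:
            return converter(val)
        except ValueError:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, converter),
                backend_id=self._backend_id)

    def merge(self, initial, configured):
        # defaults first, then VW parameters picked out of the project config
        params = dict(initial)
        params.update(self.defaults())
        params.update({param: self.convert(param, val)
                       for param, val in configured.items()
                       if param in self._spec})
        return params

With such a helper, _create_params() in the backend would reduce to something like VWParameterSpec(self.VW_PARAMS, self.backend_id).merge(params, self.params), trimming the class without changing its behaviour.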

"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""

import os
import random
import numpy as np
from vowpalwabbit import pyvw
import annif.project
import annif.util
from annif.suggestion import ListSuggestionResult, VectorSuggestionResult
from annif.exception import ConfigurationException
from annif.exception import NotInitializedException
from . import backend
from . import mixins


class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifLearningBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif"""

    name = "vw_multi"
    needs_subject_index = True

    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    VW_PARAMS = {
        'bit_precision': (int, None),
        'ngram': (lambda x: '_{}'.format(int(x)), None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    DEFAULT_INPUTS = '_text_'

    DEFAULT_PARAMS = {'algorithm': 'oaa'}

    def initialize(self):
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)
            if not os.path.exists(path):
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)
            self.debug('loading VW model from {}'.format(path))
            params = self._create_params({'i': path, 'quiet': True})
            if 'passes' in params:
                # don't confuse the model with passes
                del params['passes']
            self.debug("model parameters: {}".format(params))
            self._model = pyvw.vw(**params)
            self.debug('loaded model {}'.format(str(self._model)))

    def _convert_param(self, param, val):
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        try:
            return pspec(val)
        except ValueError:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id)

    def _create_params(self, params):
        params = params.copy()  # don't mutate the original dict
        params.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return params

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMS.copy()
        params.update(mixins.ChunkingBackend.DEFAULT_PARAMS)
        params.update(self.DEFAULT_PARAMS)
        params.update({param: default_val
                       for param, (_, default_val) in self.VW_PARAMS.items()
                       if default_val is not None})
        return params

    @property
    def algorithm(self):
        algorithm = self.params['algorithm']
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ConfigurationException(
                "{} is not a valid algorithm (allowed: {})".format(
                    algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)),
                backend_id=self.backend_id)
        return algorithm

    @property
    def inputs(self):
        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
        return inputs.split(',')

    @staticmethod
    def _cleanup_text(text):
        # colon and pipe chars have special meaning in VW and must be avoided
        return text.replace(':', '').replace('|', '')

    def _normalize_text(self, text):
        ntext = ' '.join(self.project.analyzer.tokenize_words(text))
        return VWMultiBackend._cleanup_text(ntext)

    def _uris_to_subject_ids(self, uris):
        subject_ids = []
        for uri in uris:
            subject_id = self.project.subjects.by_uri(uri)
            if subject_id is not None:
                subject_ids.append(subject_id)
        return subject_ids

    def _format_examples(self, text, uris):
        subject_ids = self._uris_to_subject_ids(uris)
        if self.algorithm == 'multilabel_oaa':
            yield '{} {}'.format(','.join(map(str, subject_ids)), text)
        else:
            for subject_id in subject_ids:
                yield '{} {}'.format(subject_id + 1, text)

    def _get_input(self, input, text):
        if input == '_text_':
            return self._normalize_text(text)
        else:
            proj = annif.project.get_project(input)
            result = proj.suggest(text)
            features = [
                '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
                for hit in result.hits]
            return ' '.join(features)

    def _inputs_to_exampletext(self, text):
        namespaces = {}
        for input in self.inputs:
            inputtext = self._get_input(input, text)
            if inputtext:
                namespaces[input] = inputtext
        if not namespaces:
            return None
        return ' '.join(['|{} {}'.format(namespace, featurestr)
                         for namespace, featurestr in namespaces.items()])

    def _create_examples(self, corpus):
        examples = []
        for doc in corpus.documents:
            text = self._inputs_to_exampletext(doc.text)
            if not text:
                continue
            examples.extend(self._format_examples(text, doc.uris))
        random.shuffle(examples)
        return examples

    def _create_model(self):
        self.info('creating VW model (algorithm: {})'.format(self.algorithm))
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        initial_params = {'data': trainpath,
                          self.algorithm: len(self.project.subjects)}
        params = self._create_params(initial_params)
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)

    def _convert_result(self, result):
        if self.algorithm == 'multilabel_oaa':
            # result is a list of subject IDs - need to vectorize
            mask = np.zeros(len(self.project.subjects), dtype=np.float32)
            mask[result] = 1.0
            return mask
        elif isinstance(result, int):
            # result is a single integer - need to one-hot-encode
            mask = np.zeros(len(self.project.subjects), dtype=np.float32)
            mask[result - 1] = 1.0
            return mask
        else:
            # result is a list of scores (probabilities or binary 1/0)
            return np.array(result, dtype=np.float32)

    def _suggest_chunks(self, chunktexts):
        results = []
        for chunktext in chunktexts:
            exampletext = self._inputs_to_exampletext(chunktext)
            if not exampletext:
                continue
            example = ' {}'.format(exampletext)
            result = self._model.predict(example)
            results.append(self._convert_result(result))
        if not results:  # empty result
            return ListSuggestionResult(
                hits=[], subject_index=self.project.subjects)
        return VectorSuggestionResult(
            np.array(results, dtype=np.float32).mean(axis=0),
            self.project.subjects)

    @staticmethod
    def _write_train_file(examples, filename):
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    def _create_train_file(self, corpus):
        self.info('creating VW train file')
        examples = self._create_examples(corpus)
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def train(self, corpus):
        self.info("creating VW model")
        self._create_train_file(corpus)
        self._create_model()

    def learn(self, corpus):
        self.initialize()
        for example in self._create_examples(corpus):
            self._model.learn(example)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)
238