Passed
Push — master ( c8c370...dee89b )
by Osma
03:14
created

VWMultiBackend._suggest_chunks()   A

Complexity

Conditions 4

Size

Total Lines 14
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 14
rs 9.7
c 0
b 0
f 0
cc 4
nop 3
1
"""Annif backend using the Vowpal Wabbit multiclass and multilabel
2
classifiers"""
3
4
import random
5
import os.path
6
import annif.util
7
from vowpalwabbit import pyvw
8
import numpy as np
9
from annif.suggestion import ListSuggestionResult, VectorSuggestionResult
10
from annif.exception import ConfigurationException, NotInitializedException
11
from . import backend
12
from . import mixins
13
14
15
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifLearningBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif"""

    name = "vw_multi"
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'ngram': (lambda x: '_{}'.format(int(x)), None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    # textual values (compared case-insensitively) that switch a boolean
    # parameter off; plain bool() would treat all of them as True
    _FALSE_STRINGS = ('0', 'false', 'no', 'off')

    DEFAULT_ALGORITHM = 'oaa'
    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    DEFAULT_INPUTS = '_text_'

    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the saved VW model from the data directory, unless a model
        is already held in memory.

        Raises NotInitializedException if the model file does not exist."""
        if self._model is not None:
            return
        path = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading VW model from {}'.format(path))
        params = self._create_params({'i': path, 'quiet': True})
        # don't confuse the model with passes
        params.pop('passes', None)
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        self.debug('loaded model {}'.format(str(self._model)))

    @property
    def algorithm(self):
        """The configured VW algorithm name (DEFAULT_ALGORITHM if not set).

        Raises ConfigurationException if the configured value is not one of
        SUPPORTED_ALGORITHMS."""
        algorithm = self.params.get('algorithm', self.DEFAULT_ALGORITHM)
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ConfigurationException(
                "{} is not a valid algorithm (allowed: {})".format(
                    algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)),
                backend_id=self.backend_id)
        return algorithm

    @property
    def inputs(self):
        """The configured input names as a list, obtained by splitting the
        comma-separated 'inputs' parameter (DEFAULT_INPUTS if not set)."""
        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
        return inputs.split(',')

    @staticmethod
    def _cleanup_text(text):
        """Return the text with all colon and pipe characters removed."""
        # colon and pipe chars have special meaning in VW and must be avoided
        return text.replace(':', '').replace('|', '')

    @staticmethod
    def _normalize_text(project, text):
        """Tokenize the text with the project analyzer, rejoin the tokens
        with single spaces and strip the characters reserved by VW."""
        ntext = ' '.join(project.analyzer.tokenize_words(text))
        return VWMultiBackend._cleanup_text(ntext)

    @staticmethod
    def _write_train_file(examples, filename):
        """Write the examples to the given file as UTF-8, one per line."""
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    @staticmethod
    def _uris_to_subject_ids(project, uris):
        """Map subject URIs to subject IDs using the project subject index,
        leaving out URIs for which the index returns None."""
        subject_ids = []
        for uri in uris:
            subject_id = project.subjects.by_uri(uri)
            if subject_id is not None:
                subject_ids.append(subject_id)
        return subject_ids

    def _format_examples(self, project, text, uris):
        """Generate VW example lines (label part followed by the text) for
        one document.

        With multilabel_oaa a single example labelled with all the subject
        IDs (comma-separated) is produced; with the other algorithms one
        example is produced per subject."""
        subject_ids = self._uris_to_subject_ids(project, uris)
        if self.algorithm == 'multilabel_oaa':
            yield '{} {}'.format(','.join(map(str, subject_ids)), text)
        else:
            for subject_id in subject_ids:
                # labels are shifted up by one here; _convert_result
                # shifts the predicted label back down
                yield '{} {}'.format(subject_id + 1, text)

    def _get_input(self, input, project, text):
        """Return the feature string for a single input.

        The special input '_text_' gives the normalized document text. Any
        other value is taken to be the ID of another project: the text is
        passed to that project and its suggestions are turned into
        'uri:score' features."""
        if input == '_text_':
            return self._normalize_text(project, text)
        # The module only imports annif.util at the top, so make sure the
        # project submodule is really loaded before using it.
        import annif.project
        proj = annif.project.get_project(input)
        result = proj.suggest(text)
        features = [
            '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
            for hit in result.hits]
        return ' '.join(features)

    def _inputs_to_exampletext(self, project, text):
        """Build the feature part of a VW example, with one '|name features'
        namespace for each configured input that produced any features.

        Returns None if no input produced features."""
        namespaces = {}
        for input in self.inputs:
            inputtext = self._get_input(input, project, text)
            if inputtext:
                namespaces[input] = inputtext
        if not namespaces:
            return None
        return ' '.join(['|{} {}'.format(namespace, featurestr)
                         for namespace, featurestr in namespaces.items()])

    def _create_examples(self, corpus, project):
        """Convert the documents of the corpus into a shuffled list of VW
        examples, skipping documents that produce no features."""
        examples = []
        for doc in corpus.documents:
            text = self._inputs_to_exampletext(project, doc.text)
            if not text:
                continue
            examples.extend(self._format_examples(project, text, doc.uris))
        random.shuffle(examples)
        return examples

    def _create_train_file(self, corpus, project):
        """Create the training examples and save them into TRAIN_FILE in
        the data directory."""
        self.info('creating VW train file')
        examples = self._create_examples(corpus, project)
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    @classmethod
    def _to_bool(cls, val):
        """Interpret a parameter value as a boolean.

        bool('false') is True, so strings are compared (ignoring case and
        surrounding whitespace) against _FALSE_STRINGS first; every other
        value is converted using the normal truthiness rules."""
        if isinstance(val, str) and val.strip().lower() in cls._FALSE_STRINGS:
            return False
        return bool(val)

    def _convert_param(self, param, val):
        """Validate and convert the value of a single VW parameter according
        to its specification in VW_PARAMS.

        Raises ConfigurationException if the value is not among the allowed
        values or cannot be converted to the expected type."""
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        converter = self._to_bool if pspec is bool else pspec
        try:
            return converter(val)
        except (TypeError, ValueError) as err:
            # TypeError covers values such as None that the converters
            # reject outright instead of failing to parse
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id) from err

    def _create_params(self, params):
        """Add the VW parameters to the given dict and return it.

        The dict is updated in place: first with the non-None defaults from
        VW_PARAMS, then with the converted values of those backend
        parameters that are listed in VW_PARAMS."""
        params.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return params

    def _create_model(self, project):
        """Create a VW model from the train file in the data directory and
        save it into MODEL_FILE."""
        self.info('creating VW model (algorithm: {})'.format(self.algorithm))
        trainpath = os.path.join(self.datadir, self.TRAIN_FILE)
        params = self._create_params(
            {'data': trainpath, self.algorithm: len(project.subjects)})
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Train a new model on the corpus: write the train file, then
        create and save the model."""
        self._create_train_file(corpus, project)
        self._create_model(project)

    def learn(self, corpus, project):
        """Continue training the existing model with the documents of the
        corpus and save the updated model.

        Raises NotInitializedException if there is no saved model to
        load."""
        # make sure the saved model has been loaded; on a fresh instance
        # self._model is still None
        self.initialize()
        for example in self._create_examples(corpus, project):
            self._model.learn(example)
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)

    def _convert_result(self, result, project):
        """Convert a VW prediction into a NumPy score vector."""
        if self.algorithm == 'multilabel_oaa':
            # result is a list of subject IDs - need to vectorize
            mask = np.zeros(len(project.subjects))
            mask[result] = 1.0
            return mask
        elif isinstance(result, int):
            # result is a single integer - need to one-hot-encode
            # (undoing the label shift made in _format_examples)
            mask = np.zeros(len(project.subjects))
            mask[result - 1] = 1.0
            return mask
        else:
            # result is a list of scores (probabilities or binary 1/0)
            return np.array(result)

    def _suggest_chunks(self, chunktexts, project):
        """Predict subjects for each chunk of text and return the mean of
        the chunk score vectors as a suggestion result.

        Chunks that produce no features are skipped; if none remain, an
        empty ListSuggestionResult is returned."""
        results = []
        for chunktext in chunktexts:
            exampletext = self._inputs_to_exampletext(project, chunktext)
            if not exampletext:
                continue
            # prediction examples carry no label, only the feature part
            example = ' {}'.format(exampletext)
            result = self._model.predict(example)
            results.append(self._convert_result(result, project))
        if not results:  # empty result
            return ListSuggestionResult(
                hits=[], subject_index=project.subjects)
        return VectorSuggestionResult(
            np.array(results).mean(axis=0), project.subjects)
226