Passed
Pull Request — master (#256)
by Osma
03:08
created

VWMultiBackend._inputs_to_exampletext()   A

Complexity

Conditions 4

Size

Total Lines 10
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 10
rs 9.9
c 0
b 0
f 0
cc 4
nop 3
"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""
3
4
import random
5
import os.path
6
import annif.util
7
from vowpalwabbit import pyvw
8
import numpy as np
9
from annif.hit import ListAnalysisResult, VectorAnalysisResult
10
from annif.exception import ConfigurationException, NotInitializedException
11
from . import backend
12
from . import mixins
13
14
15
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif"""

    name = "vw_multi"
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'ngram': (lambda x: '_{}'.format(int(x)), None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    DEFAULT_ALGORITHM = 'oaa'
    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    DEFAULT_INPUTS = '_text_'

    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the trained VW model from the data directory, unless one
        has already been loaded.

        Raises NotInitializedException if the model file does not exist."""
        if self._model is None:
            path = os.path.join(self._get_datadir(), self.MODEL_FILE)
            if not os.path.exists(path):
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)
            self.debug('loading VW model from {}'.format(path))
            params = self._create_params({'i': path, 'quiet': True})
            if 'passes' in params:
                # don't confuse the model with passes
                del params['passes']
            self.debug("model parameters: {}".format(params))
            self._model = pyvw.vw(**params)
            self.debug('loaded model {}'.format(str(self._model)))

    @property
    def algorithm(self):
        """Return the configured VW algorithm, validated against
        SUPPORTED_ALGORITHMS.

        Raises ConfigurationException for an unsupported algorithm."""
        algorithm = self.params.get('algorithm', self.DEFAULT_ALGORITHM)
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ConfigurationException(
                "{} is not a valid algorithm (allowed: {})".format(
                    algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)),
                backend_id=self.backend_id)
        return algorithm

    @property
    def inputs(self):
        """Return the configured input sources as a list of names."""
        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
        return inputs.split(',')

    @staticmethod
    def _cleanup_text(text):
        # colon and pipe chars have special meaning in VW and must be avoided
        return text.replace(':', '').replace('|', '')

    @staticmethod
    def _normalize_text(project, text):
        """Tokenize text using the project analyzer and strip characters
        that are special to the VW input format."""
        ntext = ' '.join(project.analyzer.tokenize_words(text))
        return VWMultiBackend._cleanup_text(ntext)

    @staticmethod
    def _write_train_file(examples, filename):
        """Write one VW training example per line into the given file."""
        with open(filename, 'w') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    @staticmethod
    def _uris_to_subject_ids(project, uris):
        """Map subject URIs to subject IDs, silently skipping URIs that are
        unknown to the project's subject index."""
        subject_ids = []
        for uri in uris:
            subject_id = project.subjects.by_uri(uri)
            if subject_id is not None:
                subject_ids.append(subject_id)
        return subject_ids

    def _format_examples(self, project, text, uris):
        """Yield VW-formatted training examples for one document.

        multilabel_oaa takes a single example with comma-separated
        zero-based labels; the other algorithms take one example per label
        using one-based class numbers."""
        subject_ids = self._uris_to_subject_ids(project, uris)
        if self.algorithm == 'multilabel_oaa':
            yield '{} {}'.format(','.join(map(str, subject_ids)), text)
        else:
            for subject_id in subject_ids:
                yield '{} {}'.format(subject_id + 1, text)

    def _get_input(self, input, project, text):
        """Return the feature string for one input source: the normalized
        document text for '_text_', otherwise uri:score features produced
        by analyzing the text with another Annif project."""
        if input == '_text_':
            return self._normalize_text(project, text)
        else:
            proj = annif.project.get_project(input)
            result = proj.analyze(text)
            features = [
                '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
                for hit in result.hits]
            return ' '.join(features)

    def _inputs_to_exampletext(self, project, text):
        """Combine all configured inputs into a single VW example string
        with one namespace per input, or return None if no input produced
        any features."""
        namespaces = {}
        for input in self.inputs:
            inputtext = self._get_input(input, project, text)
            if inputtext:
                namespaces[input] = inputtext
        if not namespaces:
            return None
        return ' '.join(['|{} {}'.format(namespace, featurestr)
                         for namespace, featurestr in namespaces.items()])

    def _create_train_file(self, corpus, project):
        """Convert the document corpus into a shuffled VW train file and
        save it atomically into the data directory."""
        self.info('creating VW train file')
        examples = []
        for doc in corpus.documents:
            text = self._inputs_to_exampletext(project, doc.text)
            if not text:
                # skip documents that yield no features at all
                continue
            examples.extend(self._format_examples(project, text, doc.uris))
        random.shuffle(examples)
        annif.util.atomic_save(examples,
                               self._get_datadir(),
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _convert_param(self, param, val):
        """Validate and convert a configured parameter value according to
        its VW_PARAMS specifier.

        Raises ConfigurationException for values that are not allowed or
        cannot be converted."""
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        if pspec is bool and isinstance(val, str):
            # bool('False') would be True, so parse boolean-valued
            # configuration strings explicitly instead of calling bool()
            return val.lower() in ('true', 'yes', '1')
        try:
            return pspec(val)
        except ValueError as err:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id) from err

    def _create_params(self, params):
        """Merge the given params dict (modified in place) with VW_PARAMS
        defaults and the configured backend parameters, converting values
        as needed, and return it."""
        params.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return params

    def _create_model(self, project):
        """Train a VW model from the train file and save it into the data
        directory."""
        self.info('creating VW model (algorithm: {})'.format(self.algorithm))
        trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
        params = self._create_params(
            {'data': trainpath, self.algorithm: len(project.subjects)})
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Train the backend: build the train file, then train and save
        the model."""
        self._create_train_file(corpus, project)
        self._create_model(project)

    def _convert_result(self, result, project):
        """Convert a raw VW prediction into a score vector covering all
        subjects of the project."""
        if self.algorithm == 'multilabel_oaa':
            # result is a list of subject IDs - need to vectorize
            mask = np.zeros(len(project.subjects))
            mask[result] = 1.0
            return mask
        elif isinstance(result, int):
            # result is a single integer - need to one-hot-encode
            # (VW class numbers are one-based, hence the -1)
            mask = np.zeros(len(project.subjects))
            mask[result - 1] = 1.0
            return mask
        else:
            # result is a list of scores (probabilities or binary 1/0)
            return np.array(result)

    def _analyze_chunks(self, chunktexts, project):
        """Predict subjects for each text chunk and return the mean of the
        per-chunk score vectors; returns an empty ListAnalysisResult if no
        chunk produced any features."""
        results = []
        for chunktext in chunktexts:
            exampletext = self._inputs_to_exampletext(project, chunktext)
            if not exampletext:
                continue
            example = ' {}'.format(exampletext)
            result = self._model.predict(example)
            results.append(self._convert_result(result, project))
        if not results:  # empty result
            return ListAnalysisResult(hits=[], subject_index=project.subjects)
        return VectorAnalysisResult(
            np.array(results).mean(axis=0), project.subjects)
215