Passed
Pull Request — master (#254)
by Osma
02:34
created

VWMultiBackend._create_train_file()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 11
rs 9.85
c 0
b 0
f 0
cc 2
nop 3
"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""
3
4
import random
5
import os.path
6
import annif.util
7
from vowpalwabbit import pyvw
8
import numpy as np
9
from annif.hit import VectorAnalysisResult
10
from annif.exception import ConfigurationException, NotInitializedException
11
from . import backend
12
from . import mixins
13
14
15
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif.

    Training writes a shuffled file of VW-format examples into the
    backend data directory, trains a VW model from it and saves the model
    next to it.  Analysis feeds text chunks to the loaded model and
    averages the per-chunk result vectors.

    NOTE(review): ``self.params``, ``self.backend_id``, ``self.debug``,
    ``self.info`` and ``self._get_datadir`` are not defined in this class;
    presumably they are provided by the base classes - confirm there.
    """

    name = "vw_multi"
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'ngram': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    DEFAULT_ALGORITHM = 'oaa'
    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # string spellings accepted for boolean parameters; compared after
    # stripping whitespace and lowercasing
    _TRUE_STRINGS = frozenset(('1', 'true', 'yes', 'on'))
    _FALSE_STRINGS = frozenset(('0', 'false', 'no', 'off'))

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the saved VW model unless one is already loaded.

        Raises:
            NotInitializedException: if the model file does not exist in
                the data directory.
        """
        if self._model is not None:
            return
        path = os.path.join(self._get_datadir(), self.MODEL_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading VW model from {}'.format(path))
        params = self._create_params({'i': path, 'quiet': True})
        # don't confuse the model with passes
        params.pop('passes', None)
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        self.debug('loaded model {}'.format(str(self._model)))

    @property
    def algorithm(self):
        """The configured VW algorithm name (DEFAULT_ALGORITHM if unset).

        Raises:
            ConfigurationException: if the configured value is not one of
                SUPPORTED_ALGORITHMS.
        """
        algorithm = self.params.get('algorithm', self.DEFAULT_ALGORITHM)
        if algorithm not in self.SUPPORTED_ALGORITHMS:
            raise ConfigurationException(
                "{} is not a valid algorithm (allowed: {})".format(
                    algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)),
                backend_id=self.backend_id)
        return algorithm

    @classmethod
    def _normalize_text(cls, project, text):
        """Tokenize text with the project analyzer and re-join the tokens
        with single spaces, dropping characters that VW treats specially."""
        ntext = ' '.join(project.analyzer.tokenize_words(text))
        # colon and pipe chars have special meaning in VW and must be avoided
        return ntext.replace(':', '').replace('|', '')

    @classmethod
    def _write_train_file(cls, examples, filename):
        """Write the examples to filename, one example per line."""
        # Use an explicit encoding so the file contents do not depend on
        # the platform's default locale.
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    @classmethod
    def _uris_to_subject_ids(cls, project, uris):
        """Return the subject IDs for the given URIs, in order, skipping
        URIs for which the project's subject lookup returns None."""
        subject_ids = []
        for uri in uris:
            subject_id = project.subjects.by_uri(uri)
            if subject_id is not None:
                subject_ids.append(subject_id)
        return subject_ids

    def _format_examples(self, project, text, uris):
        """Yield VW example lines ("<label> | <text>") for one document.

        With the 'multilabel_oaa' algorithm a single example is yielded
        whose label part is the comma-separated list of subject IDs.
        With the other algorithms one example per subject is yielded,
        labelled with subject ID + 1 (the inverse shift, ``result - 1``,
        is applied to predictions in _analyze_chunks).
        """
        subject_ids = self._uris_to_subject_ids(project, uris)
        if self.algorithm == 'multilabel_oaa':
            yield '{} | {}'.format(','.join(map(str, subject_ids)), text)
        else:
            for subject_id in subject_ids:
                yield '{} | {}'.format(subject_id + 1, text)

    def _create_train_file(self, corpus, project):
        """Build examples for every document in the corpus, shuffle them
        and save them as TRAIN_FILE in the data directory."""
        self.info('creating VW train file')
        examples = []
        for doc in corpus.documents:
            text = self._normalize_text(project, doc.text)
            examples.extend(self._format_examples(project, text, doc.uris))
        random.shuffle(examples)
        annif.util.atomic_save(examples,
                               self._get_datadir(),
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _convert_param(self, param, val):
        """Validate and convert a configured value for a VW parameter.

        Args:
            param: a key of VW_PARAMS.
            val: the configured value (typically a string).

        Returns:
            val itself for list-valued specifiers, otherwise val converted
            to the type given by the specifier.

        Raises:
            ConfigurationException: if val is not among the allowed values
                or cannot be converted to the required type.
        """
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        if pspec is bool and isinstance(val, str):
            # bool("false") is True, so string values have to be parsed
            # explicitly instead of being passed through bool()
            token = val.strip().lower()
            if token in self._TRUE_STRINGS:
                return True
            if token in self._FALSE_STRINGS:
                return False
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id)
        try:
            return pspec(val)
        except ValueError as err:
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec), backend_id=self.backend_id) from err

    def _create_params(self, params):
        """Add VW parameters to the given dict and return it.

        The dict is updated in place: first with the non-None defaults
        from VW_PARAMS, then with the converted values of every configured
        parameter that appears in VW_PARAMS (so configuration wins over
        defaults).
        """
        params.update({param: defaultval
                       for param, (_, defaultval) in self.VW_PARAMS.items()
                       if defaultval is not None})
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        return params

    def _create_model(self, project):
        """Train a VW model from TRAIN_FILE and save it as MODEL_FILE."""
        self.info('creating VW model (algorithm: {})'.format(self.algorithm))
        trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
        # the algorithm option takes the number of classes as its value
        params = self._create_params(
            {'data': trainpath, self.algorithm: len(project.subjects)})
        if params.get('passes', 1) > 1:
            # need a cache file when there are multiple passes
            params.update({'cache': True, 'kill_cache': True})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(**params)
        modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Create the train file from the corpus, then train and save the
        model."""
        self._create_train_file(corpus, project)
        self._create_model(project)

    def _analyze_chunks(self, chunktexts, project):
        """Predict on each text chunk and return the averaged result.

        Each chunk is normalized exactly like the training text, so that
        the model sees the same kind of tokens and so that colon and pipe
        characters in the input cannot be misread as VW syntax.

        Returns:
            A VectorAnalysisResult wrapping the element-wise mean of the
            per-chunk vectors; an all-zero vector when there are no chunks.
        """
        multilabel = self.algorithm == 'multilabel_oaa'
        n_subjects = len(project.subjects)
        results = []
        for chunktext in chunktexts:
            example = ' | {}'.format(self._normalize_text(project, chunktext))
            result = self._model.predict(example)
            if multilabel:
                # result is a list of subject IDs - need to vectorize
                mask = np.zeros(n_subjects)
                mask[result] = 1.0
                result = mask
            elif isinstance(result, int):
                # result is a single integer - need to one-hot-encode
                # (labels were shifted by +1 when the examples were written)
                mask = np.zeros(n_subjects)
                mask[result - 1] = 1.0
                result = mask
            else:
                result = np.array(result)
            results.append(result)
        if not results:
            # the mean of an empty array is NaN; report "no evidence for
            # any subject" instead
            return VectorAnalysisResult(np.zeros(n_subjects),
                                        project.subjects)
        return VectorAnalysisResult(
            np.array(results).mean(axis=0), project.subjects)
172