Completed
Push — master ( d8a4d2...b3163f )
by Osma
16s queued 11s
created

VWMultiBackend._write_train_file()   A

Complexity

Conditions 3

Size

Total Lines 4
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 3
nop 3
1
"""Annif backend using the Vowpal Wabbit multiclass and multilabel
2
classifiers"""
3
4
import random
5
import os.path
6
import annif.util
7
from vowpalwabbit import pyvw
8
import numpy as np
9
from annif.hit import AnalysisHit, VectorAnalysisResult
10
from annif.exception import ConfigurationException, NotInitializedException
11
from . import backend
12
from . import mixins
13
14
15
class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif.

    Training writes one VW example per (document, subject) pair into a
    training file and builds a model with VW's ``oaa`` option from it.
    Analysis averages the model's predictions over the text chunks.
    """

    name = "vw_multi"
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # file names, relative to the backend's data directory
    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    _model = None

    def initialize(self):
        """Load the saved VW model from the data directory, unless a model
        is already loaded.

        Raises:
            NotInitializedException: if the model file does not exist.
        """
        if self._model is not None:
            return
        path = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self.debug('loading VW model from {}'.format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                'model {} not found'.format(path),
                backend_id=self.backend_id)
        self._model = pyvw.vw(
            i=path,
            quiet=True,
            loss_function='logistic',
            probabilities=True)
        self.debug('loaded model {}'.format(str(self._model)))

    @classmethod
    def _normalize_text(cls, project, text):
        """Tokenize text with the project's analyzer and return the tokens
        joined by single spaces, with VW's special characters removed."""
        ntext = ' '.join(project.analyzer.tokenize_words(text))
        # colon and pipe chars have special meaning in VW and must be avoided
        return ntext.replace(':', '').replace('|', '')

    def _write_train_file(self, examples, filename):
        """Write the given examples to filename, one example per line."""
        # Explicit UTF-8 so that the bytes written do not depend on the
        # platform's locale encoding (which may also be unable to encode
        # some characters of the training texts).
        with open(filename, 'w', encoding='utf-8') as trainfile:
            for ex in examples:
                print(ex, file=trainfile)

    def _create_train_file(self, corpus, project):
        """Build the VW training examples from the corpus documents and
        save them, in shuffled order, as TRAIN_FILE in the data directory.

        A document produces one example for each of its subject URIs;
        URIs that the project's subject index cannot resolve are skipped.
        """
        self.info('creating VW train file')
        examples = []
        for doc in corpus.documents:
            text = self._normalize_text(project, doc.text)
            for uri in doc.uris:
                subject_id = project.subjects.by_uri(uri)
                if subject_id is None:
                    continue
                # the class label written to the file is subject_id + 1
                exstr = '{} | {}'.format(subject_id + 1, text)
                examples.append(exstr)
        random.shuffle(examples)
        annif.util.atomic_save(examples,
                               self._get_datadir(),
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def _convert_param(self, param, val):
        """Validate and convert a configured value of a VW parameter.

        If the VW_PARAMS specifier of param is a list, val must be one of
        its members and is returned as is; otherwise the specifier is
        called to convert val.

        Raises:
            ConfigurationException: if val is not an allowed value or
                cannot be converted.
        """
        pspec, _ = self.VW_PARAMS[param]
        if isinstance(pspec, list):
            if val in pspec:
                return val
            raise ConfigurationException(
                "{} is not a valid value for {} (allowed: {})".format(
                    val, param, ', '.join(pspec)), backend_id=self.backend_id)
        try:
            return pspec(val)
        except (TypeError, ValueError) as err:
            # chain the original error so that the cause stays visible
            raise ConfigurationException(
                "The {} value {} cannot be converted to {}".format(
                    param, val, pspec),
                backend_id=self.backend_id) from err

    def _create_model(self, project):
        """Train a VW model on TRAIN_FILE and save it as MODEL_FILE.

        The VW parameters are the non-None defaults from VW_PARAMS,
        overridden by the matching entries of self.params.
        """
        self.info('creating VW model')
        trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
        params = {param: defaultval
                  for param, (_, defaultval) in self.VW_PARAMS.items()
                  if defaultval is not None}
        params.update({param: self._convert_param(param, val)
                       for param, val in self.params.items()
                       if param in self.VW_PARAMS})
        self.debug("model parameters: {}".format(params))
        self._model = pyvw.vw(
            oaa=len(project.subjects),
            probabilities=True,
            data=trainpath,
            **params)
        modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
        self._model.save(modelpath)

    def train(self, corpus, project):
        """Create the training file from the corpus, then build and save
        the model."""
        self._create_train_file(corpus, project)
        self._create_model(project)

    def _analyze_chunks(self, chunktexts, project):
        """Predict each chunk with the model and return the element-wise
        mean of the predictions as a VectorAnalysisResult.

        When there are no chunks, an all-zero vector with one entry per
        subject is used instead of averaging nothing (which would be NaN).
        """
        results = [np.array(self._model.predict(' | {}'.format(chunktext)))
                   for chunktext in chunktexts]
        if not results:
            # same length as the class count the model is created with
            # (oaa=len(project.subjects))
            return VectorAnalysisResult(
                np.zeros(len(project.subjects)), project.subjects)
        return VectorAnalysisResult(
            np.array(results).mean(axis=0), project.subjects)
127