Completed
Push — master ( 09c408...852165 )
by Osma
16s queued 11s
created

annif.backend.vw_multi   A

Complexity

Total Complexity 28

Size/Duplication

Total Lines 142
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 113
dl 0
loc 142
rs 10
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A VWMultiBackend._suggest_chunks() 0 14 4
A VWMultiBackend._inputs_to_exampletext() 0 10 4
A VWMultiBackend._normalize_text() 0 4 1
A VWMultiBackend.inputs() 0 4 1
A VWMultiBackend._create_model() 0 3 1
A VWMultiBackend._create_examples() 0 9 3
A VWMultiBackend._convert_result() 0 14 3
A VWMultiBackend._cleanup_text() 0 4 1
A VWMultiBackend._get_input() 0 10 2
A VWMultiBackend._uris_to_subject_ids() 0 8 3
A VWMultiBackend.algorithm() 0 9 2
A VWMultiBackend._format_examples() 0 7 3
"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""
3
4
import random
5
import numpy as np
6
import annif.project
7
from annif.suggestion import ListSuggestionResult, VectorSuggestionResult
8
from annif.exception import ConfigurationException
9
from . import vw_base
10
from . import mixins
11
12
13
class VWMultiBackend(mixins.ChunkingBackend, vw_base.VWBaseBackend):
    """Vowpal Wabbit multiclass/multilabel backend for Annif.

    Training documents and texts to analyze are turned into VW example
    lines that contain one namespace per configured input (see `inputs`).
    """

    name = "vw_multi"
    needs_subject_index = True

    # Parameters recognized by this backend; each value is a pair whose
    # second item looks like the default (None = no default).
    # NOTE(review): the pairs are interpreted by vw_base.VWBaseBackend,
    # which is not visible here - confirm the exact semantics there.
    VW_PARAMS = {
        'bit_precision': (int, None),
        'ngram': (lambda x: '_{}'.format(int(x)), None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    DEFAULT_ALGORITHM = 'oaa'
    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    # special input name that stands for the document text itself
    DEFAULT_INPUTS = '_text_'

    @property
    def algorithm(self):
        """Return the configured algorithm name.

        Falls back to DEFAULT_ALGORITHM when the 'algorithm' parameter is
        not set. Raises ConfigurationException if the configured value is
        not one of SUPPORTED_ALGORITHMS.
        """
        algorithm = self.params.get('algorithm', self.DEFAULT_ALGORITHM)
        if algorithm in self.SUPPORTED_ALGORITHMS:
            return algorithm
        raise ConfigurationException(
            "{} is not a valid algorithm (allowed: {})".format(
                algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)),
            backend_id=self.backend_id)

    @property
    def inputs(self):
        """Return the configured input names as a list of strings.

        The 'inputs' parameter is a comma-separated string (default:
        DEFAULT_INPUTS). Whitespace around each name is stripped, so that
        "_text_, other" yields ['_text_', 'other']; an unstripped name
        would neither match the '_text_' marker / a project ID in
        _get_input nor form a valid namespace in _inputs_to_exampletext.
        """
        inputs = self.params.get('inputs', self.DEFAULT_INPUTS)
        return [input_name.strip() for input_name in inputs.split(',')]

    @staticmethod
    def _cleanup_text(text):
        """Return text with all colon and pipe characters removed."""
        # colon and pipe chars have special meaning in VW and must be avoided
        return text.translate(str.maketrans('', '', ':|'))

    @staticmethod
    def _normalize_text(project, text):
        """Tokenize text with the project analyzer and return the tokens
        joined by single spaces, with VW special characters removed."""
        tokens = project.analyzer.tokenize_words(text)
        return VWMultiBackend._cleanup_text(' '.join(tokens))

    @staticmethod
    def _uris_to_subject_ids(project, uris):
        """Return the subject IDs of the given URIs as a list, skipping
        URIs for which the project's subject index returns None."""
        looked_up = (project.subjects.by_uri(uri) for uri in uris)
        return [subject_id for subject_id in looked_up
                if subject_id is not None]

    def _format_examples(self, project, text, uris):
        """Generate VW example lines ("<label(s)> <text>") for a document.

        With the multilabel_oaa algorithm a single example is produced,
        labeled with the comma-separated subject IDs. With the other
        algorithms one example per subject is produced, labeled with the
        subject ID plus one (_convert_result reverses the shift).
        """
        subject_ids = self._uris_to_subject_ids(project, uris)
        if self.algorithm == 'multilabel_oaa':
            labels = ','.join(str(subject_id) for subject_id in subject_ids)
            yield '{} {}'.format(labels, text)
            return
        for subject_id in subject_ids:
            yield '{} {}'.format(subject_id + 1, text)

    def _get_input(self, input, project, text):
        """Return the feature string of one input for the given text.

        The special input '_text_' gives the normalized text itself. Any
        other input is taken as the ID of another project: that project
        is asked for suggestions on the text and every hit becomes a
        "<uri>:<score>" feature (the URI stripped of VW special chars).
        """
        if input == '_text_':
            return self._normalize_text(project, text)
        source_project = annif.project.get_project(input)
        suggestions = source_project.suggest(text)
        return ' '.join(
            '{}:{}'.format(self._cleanup_text(hit.uri), hit.score)
            for hit in suggestions.hits)

    def _inputs_to_exampletext(self, project, text):
        """Return the feature part of a VW example line for the text.

        Each configured input that produces a non-empty feature string
        becomes a "|<input> <features>" namespace section. Returns None
        when no input produces any features.
        """
        namespaces = {}
        for input_name in self.inputs:
            featurestr = self._get_input(input_name, project, text)
            if featurestr:
                namespaces[input_name] = featurestr
        if not namespaces:
            return None
        return ' '.join('|{} {}'.format(namespace, featurestr)
                        for namespace, featurestr in namespaces.items())

    def _create_examples(self, corpus, project):
        """Return a shuffled list of VW example lines built from the
        documents of the corpus; documents that produce no input text
        are skipped."""
        examples = []
        for doc in corpus.documents:
            text = self._inputs_to_exampletext(project, doc.text)
            if text:
                examples.extend(
                    self._format_examples(project, text, doc.uris))
        random.shuffle(examples)
        return examples

    def _create_model(self, project):
        """Create the VW model, passing the algorithm name with the
        number of subjects as its value to the base implementation."""
        self.info('creating VW model (algorithm: {})'.format(self.algorithm))
        super()._create_model(project, {self.algorithm: len(project.subjects)})

    def _convert_result(self, result, project):
        """Convert a VW prediction into a NumPy score vector.

        With multilabel_oaa the result is used as subject ID indices that
        are set to 1.0 in a zero vector. A single integer result is
        one-hot-encoded at index result - 1 (undoing the label shift of
        _format_examples). Anything else is converted to an array as is.
        """
        if self.algorithm == 'multilabel_oaa':
            # result is a list of subject IDs - need to vectorize
            vector = np.zeros(len(project.subjects))
            vector[result] = 1.0
            return vector
        if isinstance(result, int):
            # result is a single integer - need to one-hot-encode
            vector = np.zeros(len(project.subjects))
            vector[result - 1] = 1.0
            return vector
        # result is a list of scores (probabilities or binary 1/0)
        return np.array(result)

    def _suggest_chunks(self, chunktexts, project):
        """Return a suggestion result for the given text chunks.

        Each chunk that produces input text is sent to the model and the
        converted predictions are averaged over the chunks. If no chunk
        produces any input text, an empty ListSuggestionResult is
        returned instead.
        """
        vectors = []
        for chunktext in chunktexts:
            exampletext = self._inputs_to_exampletext(project, chunktext)
            if not exampletext:
                continue
            # leading space: the example carries no label
            prediction = self._model.predict(' {}'.format(exampletext))
            vectors.append(self._convert_result(prediction, project))
        if not vectors:  # empty result
            return ListSuggestionResult(
                hits=[], subject_index=project.subjects)
        return VectorSuggestionResult(
            np.array(vectors).mean(axis=0), project.subjects)
142