Completed
Push — master ( 09c408...852165 )
by Osma
16s queued 11s
created

annif.backend.vw_ensemble   A

Complexity

Total Complexity 28

Size/Duplication

Total Lines 173
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 28
eloc 136
dl 0
loc 173
rs 10
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A VWEnsembleBackend._write_freq_file() 0 4 2
A VWEnsembleBackend._load_subject_freq() 0 16 4
A VWEnsembleBackend._create_train_file() 0 16 1
A VWEnsembleBackend.initialize() 0 4 2
A VWEnsembleBackend._calculate_scores() 0 5 1
A VWEnsembleBackend._source_project_ids() 0 4 1
A VWEnsembleBackend._create_examples() 0 8 2
A VWEnsembleBackend._merge_hits_from_sources() 0 16 3
A VWEnsembleBackend._doc_score_vector() 0 6 2
A VWEnsembleBackend.learn() 0 12 2
A VWEnsembleBackend._doc_to_example() 0 13 4
A VWEnsembleBackend._format_example() 0 11 4
1
"""Annif backend that uses Vowpal Wabbit to learn how to combine the
results of multiple source projects (an ensemble)"""
3
4
import collections
import json
import random
import os.path

import numpy as np

import annif.corpus
import annif.util
import annif.project
from annif.exception import NotInitializedException
from annif.suggestion import VectorSuggestionResult
from . import vw_base
from . import ensemble
15
16
17
class VWEnsembleBackend(
        ensemble.EnsembleBackend,
        vw_base.VWBaseBackend):
    """Vowpal Wabbit ensemble backend that combines results from multiple
    projects and learns how well those projects/backends recognize
    particular subjects."""

    name = "vw_ensemble"

    VW_PARAMS = {
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'squared'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # number of training examples per subject, stored as a collections.Counter
    _subject_freq = None

    FREQ_FILE = 'subject-freq.json'

    # The discount rate affects how quickly the ensemble starts to trust its
    # own judgement when the amount of training data increases, versus using
    # a simple mean of scores. A higher value will mean that the model
    # adapts quicker (and possibly makes more errors) while a lower value
    # will make it more careful so that it will require more training data.
    DEFAULT_DISCOUNT_RATE = 0.01

    def _load_subject_freq(self):
        """Load the per-subject training example counts from FREQ_FILE.

        Raises NotInitializedException if the file does not exist in the
        data directory."""
        path = os.path.join(self.datadir, self.FREQ_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'frequency file {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading concept frequencies from {}'.format(path))
        with open(path) as freqf:
            serialized = json.load(freqf)
        # The Counter was serialized as a JSON object, whose keys are always
        # strings; turn them back into integer subject IDs. The attribute is
        # assigned only after parsing succeeded, so a failed load leaves it
        # as None and a later initialize() call will try again.
        self._subject_freq = collections.Counter(
            {int(cid): freq for cid, freq in serialized.items()})
        self.debug('loaded frequencies for {} concepts'.format(
            len(self._subject_freq)))

    def initialize(self):
        """Load the subject frequencies (only once), then run the
        initialization of the base classes."""
        if self._subject_freq is None:
            self._load_subject_freq()
        super().initialize()

    def _calculate_scores(self, subj_id, subj_score_vector):
        """Return (raw_score, pred_score) for a single subject.

        raw_score is the plain mean of the source project scores.
        pred_score is the VW model prediction mapped with (p + 1) / 2, so
        that the training labels -1 and 1 correspond to 0 and 1."""
        ex = self._format_example(subj_id, subj_score_vector)
        raw_score = subj_score_vector.mean()
        pred_score = (self._model.predict(ex) + 1.0) / 2.0
        return raw_score, pred_score

    def _discount_rate(self, params):
        """Return the discount rate to use for merging.

        The per-call params take precedence, then the project configuration
        (self.params), then DEFAULT_DISCOUNT_RATE."""
        for source in (params, self.params):
            if source and 'discount_rate' in source:
                return float(source['discount_rate'])
        return float(self.DEFAULT_DISCOUNT_RATE)

    def _merge_hits_from_sources(self, hits_from_sources, project, params):
        """Combine the score vectors of the source projects into one result.

        For each subject that received a positive total score from the
        sources, the mean source score is blended with the model
        prediction. The weight of the mean shrinks as more training
        examples have been seen for that subject."""
        score_vector = np.array([hits.vector
                                 for hits, _ in hits_from_sources])
        discount_rate = self._discount_rate(params)
        result = np.zeros(score_vector.shape[1])
        # Only subjects suggested by at least one source need a prediction;
        # tolist() gives plain Python ints for the subject IDs.
        candidates = np.flatnonzero(score_vector.sum(axis=0) > 0.0).tolist()
        for subj_id in candidates:
            raw_score, pred_score = self._calculate_scores(
                subj_id, score_vector[:, subj_id])
            # Counter lookups of unseen subjects return 0 -> raw_weight 1.0
            raw_weight = 1.0 / (
                discount_rate * self._subject_freq[subj_id] + 1)
            result[subj_id] = (raw_weight * raw_score +
                               (1.0 - raw_weight) * pred_score)
        return VectorSuggestionResult(result, project.subjects)

    @property
    def _source_project_ids(self):
        """IDs of the source projects configured in the 'sources' param."""
        sources = annif.util.parse_sources(self.params['sources'])
        return [project_id for project_id, _ in sources]

    def _format_example(self, subject_id, scores, true=None):
        """Format one VW text-format example line for a subject.

        The line is the label, then '|' followed by the subject ID, then one
        'project_id:score' feature per source project. The label is 1 when
        true is truthy, -1 when it is falsy, and empty when true is None
        (used when predicting)."""
        if true is None:
            val = ''
        elif true:
            val = 1
        else:
            val = -1
        features = ''.join(
            ' {}:{:.6f}'.format(proj, scores[proj_idx])
            for proj_idx, proj in enumerate(self._source_project_ids))
        return '{} |{}{}'.format(val, subject_id, features)

    def _doc_score_vector(self, doc, source_projects):
        """Return an array with one row of subject scores per source
        project, suggested for the text of doc."""
        return np.array([source_project.suggest(doc.text).vector
                         for source_project in source_projects])

    def _doc_to_example(self, doc, project, source_projects):
        """Return a list of (subject_id, example) pairs for one document.

        A subject is included when it is one of the document's subjects or
        when the source projects gave it a positive total score."""
        subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
        true = subjects.as_vector(project.subjects)
        score_vector = self._doc_score_vector(doc, source_projects)
        mask = (np.asarray(true) != 0) | (score_vector.sum(axis=0) > 0.0)
        # tolist() yields plain Python ints: numpy integer IDs would end up
        # as keys of the subject frequency Counter and break json.dump.
        return [(subj_id,
                 self._format_example(subj_id,
                                      score_vector[:, subj_id],
                                      true[subj_id]))
                for subj_id in np.flatnonzero(mask).tolist()]

    def _create_examples(self, corpus, project):
        """Build (subject_id, example) pairs for every document in corpus
        and return them in random order."""
        source_projects = [annif.project.get_project(project_id)
                           for project_id in self._source_project_ids]
        examples = []
        for doc in corpus.documents:
            examples.extend(
                self._doc_to_example(doc, project, source_projects))
        random.shuffle(examples)
        return examples

    @staticmethod
    def _write_freq_file(subject_freq, filename):
        """Write the subject frequency Counter into filename as JSON.

        JSON stores the integer keys as strings; _load_subject_freq
        converts them back."""
        with open(filename, 'w') as freqfile:
            json.dump(subject_freq, freqfile)

    def _save_subject_freq(self):
        """Atomically save the current subject frequencies to FREQ_FILE."""
        annif.util.atomic_save(self._subject_freq,
                               self.datadir,
                               self.FREQ_FILE,
                               method=self._write_freq_file)

    def _create_train_file(self, corpus, project):
        """Generate training examples from corpus, save the per-subject
        example counts and write the examples into the VW train file."""
        self.info('creating VW train file')
        exampledata = self._create_examples(corpus, project)

        self._subject_freq = collections.Counter(
            subj_id for subj_id, _ in exampledata)
        self._save_subject_freq()

        examples = [ex for _, ex in exampledata]
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def learn(self, corpus, project):
        """Incrementally train the model on the documents in corpus, update
        the subject frequencies and save both into the data directory."""
        self.initialize()
        for subj_id, example in self._create_examples(corpus, project):
            self._model.learn(example)
            self._subject_freq[subj_id] += 1
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)
        self._save_subject_freq()
173