|
1
|
|
|
"""Annif backend using the Vowpal Wabbit multiclass and multilabel |
|
2
|
|
|
classifiers""" |
|
3
|
|
|
|
|
4
|
|
|
import collections
import json
import os.path
import random

import numpy as np

import annif.corpus
import annif.project
import annif.util
from annif.exception import NotInitializedException
from annif.suggestion import VectorSuggestionResult

from . import ensemble
from . import vw_base
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
class VWEnsembleBackend(
        ensemble.EnsembleBackend,
        vw_base.VWBaseBackend):
    """Vowpal Wabbit ensemble backend that combines results from multiple
    projects and learns how well those projects/backends recognize
    particular subjects."""

    # backend identifier used in project configuration
    name = "vw_ensemble"

    # Supported VW hyperparameters: maps parameter name to a
    # (type-or-allowed-values, default) pair; presumably consumed by the
    # VW base backend when building the model — confirm in vw_base.
    VW_PARAMS = {
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'squared'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # number of training examples per subject, stored as a collections.Counter
    _subject_freq = None

    # filename (under the backend data dir) for the persisted subject
    # frequency Counter, serialized as JSON
    FREQ_FILE = 'subject-freq.json'

    # The discount rate affects how quickly the ensemble starts to trust its
    # own judgement when the amount of training data increases, versus using
    # a simple mean of scores. A higher value will mean that the model
    # adapts quicker (and possibly makes more errors) while a lower value
    # will make it more careful so that it will require more training data.
    DEFAULT_DISCOUNT_RATE = 0.01

    # score threshold for "zero features": scores lower than this will be
    # considered zero and marked with a zero feature given to VW
    ZERO_THRESHOLD = 0.001
|
50
|
|
|
|
|
51
|
|
|
def _load_subject_freq(self): |
|
52
|
|
|
path = os.path.join(self.datadir, self.FREQ_FILE) |
|
53
|
|
|
if not os.path.exists(path): |
|
54
|
|
|
raise NotInitializedException( |
|
55
|
|
|
'frequency file {} not found'.format(path), |
|
56
|
|
|
backend_id=self.backend_id) |
|
57
|
|
|
self.debug('loading concept frequencies from {}'.format(path)) |
|
58
|
|
|
with open(path) as freqf: |
|
59
|
|
|
# The Counter was serialized like a dictionary, need to |
|
60
|
|
|
# convert it back. Keys that became strings need to be turned |
|
61
|
|
|
# back into integers. |
|
62
|
|
|
self._subject_freq = collections.Counter() |
|
63
|
|
|
for cid, freq in json.load(freqf).items(): |
|
64
|
|
|
self._subject_freq[int(cid)] = freq |
|
|
|
|
|
|
65
|
|
|
self.debug('loaded frequencies for {} concepts'.format( |
|
66
|
|
|
len(self._subject_freq))) |
|
67
|
|
|
|
|
68
|
|
|
    def initialize(self):
        """Load the subject frequency data (if not already loaded), then
        let the base class initialize the VW model."""
        if self._subject_freq is None:
            self._load_subject_freq()
        super().initialize()
|
72
|
|
|
|
|
73
|
|
|
def _calculate_scores(self, subj_id, subj_score_vector): |
|
74
|
|
|
ex = self._format_example(subj_id, subj_score_vector) |
|
75
|
|
|
raw_score = subj_score_vector.mean() |
|
76
|
|
|
pred_score = (self._model.predict(ex) + 1.0) / 2.0 |
|
77
|
|
|
return raw_score, pred_score |
|
78
|
|
|
|
|
79
|
|
|
    def _merge_hits_from_sources(self, hits_from_sources, project, params):
        """Combine the suggestion vectors from all source projects into one
        VectorSuggestionResult, blending the plain mean of scores with the
        VW model's prediction for each subject."""
        # rows: one suggestion vector per source project; columns: subjects
        score_vector = np.array([hits.vector
                                 for hits, _ in hits_from_sources])
        # NOTE(review): the ``params`` argument is not used here; the
        # discount rate is read from self.params instead — confirm intended.
        discount_rate = float(self.params.get('discount_rate',
                                              self.DEFAULT_DISCOUNT_RATE))
        result = np.zeros(score_vector.shape[1])
        for subj_id in range(score_vector.shape[1]):
            subj_score_vector = score_vector[:, subj_id]
            # only score subjects that at least one source suggested
            if subj_score_vector.sum() > 0.0:
                raw_score, pred_score = self._calculate_scores(
                    subj_id, subj_score_vector)
                # Weight shifts from the raw mean towards the model
                # prediction as this subject accumulates training examples.
                raw_weight = 1.0 / \
                    ((discount_rate * self._subject_freq[subj_id]) + 1)
                result[subj_id] = (raw_weight * raw_score) + \
                    (1.0 - raw_weight) * pred_score
        return VectorSuggestionResult(result, project.subjects)
|
95
|
|
|
|
|
96
|
|
|
@property |
|
97
|
|
|
def _source_project_ids(self): |
|
98
|
|
|
sources = annif.util.parse_sources(self.params['sources']) |
|
99
|
|
|
return [project_id for project_id, _ in sources] |
|
100
|
|
|
|
|
101
|
|
|
@staticmethod |
|
102
|
|
|
def _format_value(true): |
|
103
|
|
|
if true is None: |
|
104
|
|
|
return '' |
|
105
|
|
|
elif true: |
|
106
|
|
|
return 1 |
|
107
|
|
|
else: |
|
108
|
|
|
return -1 |
|
109
|
|
|
|
|
110
|
|
|
def _format_example(self, subject_id, scores, true=None): |
|
111
|
|
|
features = " ".join(["{}:{:.6f}".format(proj, scores[proj_idx]) |
|
112
|
|
|
for proj_idx, proj |
|
113
|
|
|
in enumerate(self._source_project_ids)]) |
|
114
|
|
|
zero_features = " ".join(["zero^{}".format(proj) |
|
115
|
|
|
for proj_idx, proj |
|
116
|
|
|
in enumerate(self._source_project_ids) |
|
117
|
|
|
if scores[proj_idx] < self.ZERO_THRESHOLD]) |
|
118
|
|
|
return "{} |raw {} {} |{} {} {}".format( |
|
119
|
|
|
self._format_value(true), |
|
120
|
|
|
features, |
|
121
|
|
|
zero_features, |
|
122
|
|
|
subject_id, |
|
123
|
|
|
features, |
|
124
|
|
|
zero_features) |
|
125
|
|
|
|
|
126
|
|
|
def _doc_score_vector(self, doc, source_projects): |
|
127
|
|
|
score_vectors = [] |
|
128
|
|
|
for source_project in source_projects: |
|
129
|
|
|
hits = source_project.suggest(doc.text) |
|
130
|
|
|
score_vectors.append(hits.vector) |
|
131
|
|
|
return np.array(score_vectors) |
|
132
|
|
|
|
|
133
|
|
|
def _doc_to_example(self, doc, project, source_projects): |
|
134
|
|
|
examples = [] |
|
135
|
|
|
subjects = annif.corpus.SubjectSet((doc.uris, doc.labels)) |
|
136
|
|
|
true = subjects.as_vector(project.subjects) |
|
137
|
|
|
score_vector = self._doc_score_vector(doc, source_projects) |
|
138
|
|
|
for subj_id in range(len(true)): |
|
139
|
|
|
if true[subj_id] \ |
|
140
|
|
|
or score_vector[:, subj_id].sum() >= self.ZERO_THRESHOLD: |
|
141
|
|
|
ex = (subj_id, self._format_example( |
|
142
|
|
|
subj_id, |
|
143
|
|
|
score_vector[:, subj_id], |
|
144
|
|
|
true[subj_id])) |
|
145
|
|
|
examples.append(ex) |
|
146
|
|
|
return examples |
|
147
|
|
|
|
|
148
|
|
|
def _create_examples(self, corpus, project): |
|
149
|
|
|
source_projects = [annif.project.get_project(project_id) |
|
150
|
|
|
for project_id in self._source_project_ids] |
|
151
|
|
|
examples = [] |
|
152
|
|
|
for doc in corpus.documents: |
|
153
|
|
|
examples += self._doc_to_example(doc, project, source_projects) |
|
154
|
|
|
random.shuffle(examples) |
|
155
|
|
|
return examples |
|
156
|
|
|
|
|
157
|
|
|
    def _create_model(self, project):
        """Create the VW model via the base class, passing {'q': 'rr'} to
        # add interactions between raw (descriptor-invariant) features to
        the mix."""
        # add interactions between raw (descriptor-invariant) features to
        # the mix
        super()._create_model(project, {'q': 'rr'})
|
161
|
|
|
|
|
162
|
|
|
@staticmethod |
|
163
|
|
|
def _write_freq_file(subject_freq, filename): |
|
164
|
|
|
with open(filename, 'w') as freqfile: |
|
165
|
|
|
json.dump(subject_freq, freqfile) |
|
166
|
|
|
|
|
167
|
|
|
def _create_train_file(self, corpus, project): |
|
168
|
|
|
self.info('creating VW train file') |
|
169
|
|
|
exampledata = self._create_examples(corpus, project) |
|
170
|
|
|
|
|
171
|
|
|
subjects = [subj_id for subj_id, ex in exampledata] |
|
172
|
|
|
self._subject_freq = collections.Counter(subjects) |
|
173
|
|
|
annif.util.atomic_save(self._subject_freq, |
|
174
|
|
|
self.datadir, |
|
175
|
|
|
self.FREQ_FILE, |
|
176
|
|
|
method=self._write_freq_file) |
|
177
|
|
|
|
|
178
|
|
|
examples = [ex for subj_id, ex in exampledata] |
|
179
|
|
|
annif.util.atomic_save(examples, |
|
180
|
|
|
self.datadir, |
|
181
|
|
|
self.TRAIN_FILE, |
|
182
|
|
|
method=self._write_train_file) |
|
183
|
|
|
|
|
184
|
|
|
def learn(self, corpus, project): |
|
185
|
|
|
self.initialize() |
|
186
|
|
|
exampledata = self._create_examples(corpus, project) |
|
187
|
|
|
for subj_id, example in exampledata: |
|
188
|
|
|
self._model.learn(example) |
|
189
|
|
|
self._subject_freq[subj_id] += 1 |
|
190
|
|
|
modelpath = os.path.join(self.datadir, self.MODEL_FILE) |
|
191
|
|
|
self._model.save(modelpath) |
|
192
|
|
|
annif.util.atomic_save(self._subject_freq, |
|
193
|
|
|
self.datadir, |
|
194
|
|
|
self.FREQ_FILE, |
|
195
|
|
|
method=self._write_freq_file) |
|
196
|
|
|
|