"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""

import collections
import json
import os.path
import random
import annif.corpus
import annif.project
import annif.util
import numpy as np
from annif.exception import NotInitializedException
from annif.suggestion import VectorSuggestionResult
from . import vw_base
from . import ensemble


class VWEnsembleBackend(
        ensemble.EnsembleBackend,
        vw_base.VWBaseBackend):
    """Vowpal Wabbit ensemble backend that combines results from multiple
    projects and learns how well those projects/backends recognize
    particular subjects."""

    name = "vw_ensemble"

    VW_PARAMS = {
        'bit_precision': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'squared'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None)
    }

    # Number of training examples seen per subject, stored as a
    # collections.Counter.
    _subject_freq = None

    FREQ_FILE = 'subject-freq.json'

    # The discount rate affects how quickly the ensemble starts to trust its
    # own judgement as the amount of training data increases, instead of
    # relying on a simple mean of the source scores. A higher value makes
    # the model adapt faster (possibly making more errors), while a lower
    # value makes it more cautious, so that it requires more training data.
    DEFAULT_DISCOUNT_RATE = 0.01

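    # For example, with the default rate 0.01 a previously unseen subject
    # (zero training examples) is scored purely by the mean of the source
    # scores, since raw_weight = 1 / (0.01 * 0 + 1) = 1.0; after 100
    # training examples the mean and the learned prediction are weighted
    # equally, since raw_weight = 1 / (0.01 * 100 + 1) = 0.5. See
    # _merge_hits_from_sources below.
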
    def _load_subject_freq(self):
        path = os.path.join(self.datadir, self.FREQ_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                'frequency file {} not found'.format(path),
                backend_id=self.backend_id)
        self.debug('loading concept frequencies from {}'.format(path))
        with open(path) as freqf:
            # The Counter was serialized as a plain dictionary, so it needs
            # to be converted back, turning the stringified keys back into
            # integers.
            self._subject_freq = collections.Counter()
            for cid, freq in json.load(freqf).items():
                self._subject_freq[int(cid)] = freq
        self.debug('loaded frequencies for {} concepts'.format(
            len(self._subject_freq)))

    def initialize(self):
        if self._subject_freq is None:
            self._load_subject_freq()
        super().initialize()

    def _calculate_scores(self, subj_id, subj_score_vector):
        ex = self._format_example(subj_id, subj_score_vector)
        raw_score = subj_score_vector.mean()
        # Map the raw VW prediction from the label range [-1, 1] to a
        # score in [0, 1].
        pred_score = (self._model.predict(ex) + 1.0) / 2.0
        return raw_score, pred_score

    def _merge_hits_from_sources(self, hits_from_sources, project, params):
        score_vector = np.array([hits.vector
                                 for hits, _ in hits_from_sources])
        discount_rate = float(self.params.get('discount_rate',
                                              self.DEFAULT_DISCOUNT_RATE))
        result = np.zeros(score_vector.shape[1])
        for subj_id in range(score_vector.shape[1]):
            subj_score_vector = score_vector[:, subj_id]
            if subj_score_vector.sum() > 0.0:
                raw_score, pred_score = self._calculate_scores(
                    subj_id, subj_score_vector)
                # Weight the simple mean of the source scores against the
                # learned prediction: the more training examples a subject
                # has, the more the model's own judgement is trusted.
                raw_weight = 1.0 / \
                    ((discount_rate * self._subject_freq[subj_id]) + 1)
                result[subj_id] = (raw_weight * raw_score) + \
                    (1.0 - raw_weight) * pred_score
        return VectorSuggestionResult(result, project.subjects)

    @property
    def _source_project_ids(self):
        sources = annif.util.parse_sources(self.params['sources'])
        return [project_id for project_id, _ in sources]

    def _format_example(self, subject_id, scores, true=None):
        if true is None:
            val = ''
        elif true:
            val = 1
        else:
            val = -1
        ex = "{} |{}".format(val, subject_id)
        for proj_idx, proj in enumerate(self._source_project_ids):
            ex += " {}:{:.6f}".format(proj, scores[proj_idx])
        return ex

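    # As an illustration (with hypothetical source project IDs "tfidf" and
    # "fasttext"), a positive training example for subject 42 would be
    # formatted as:
    #
    #     1 |42 tfidf:0.333333 fasttext:0.500000
    #
    # i.e. the VW label, the subject ID as the namespace, and one feature
    # per source project carrying its score for that subject.
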
    def _doc_score_vector(self, doc, source_projects):
        score_vectors = []
        for source_project in source_projects:
            hits = source_project.suggest(doc.text)
            score_vectors.append(hits.vector)
        return np.array(score_vectors)

    def _doc_to_example(self, doc, project, source_projects):
        examples = []
        subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
        true = subjects.as_vector(project.subjects)
        score_vector = self._doc_score_vector(doc, source_projects)
        for subj_id in range(len(true)):
            # Only generate examples for subjects that are either true for
            # this document or suggested by at least one source project.
            if true[subj_id] or score_vector[:, subj_id].sum() > 0.0:
                ex = (subj_id, self._format_example(
                    subj_id,
                    score_vector[:, subj_id],
                    true[subj_id]))
                examples.append(ex)
        return examples

    def _create_examples(self, corpus, project):
        source_projects = [annif.project.get_project(project_id)
                           for project_id in self._source_project_ids]
        examples = []
        for doc in corpus.documents:
            examples += self._doc_to_example(doc, project, source_projects)
        random.shuffle(examples)
        return examples

    @staticmethod
    def _write_freq_file(subject_freq, filename):
        with open(filename, 'w') as freqfile:
            json.dump(subject_freq, freqfile)

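    # The saved file is plain JSON; Counter keys (integer subject IDs)
    # become strings when dumped, e.g. {"42": 3, "123": 1}, which is why
    # _load_subject_freq converts them back to integers.
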
    def _create_train_file(self, corpus, project):
        self.info('creating VW train file')
        exampledata = self._create_examples(corpus, project)

        subjects = [subj_id for subj_id, ex in exampledata]
        self._subject_freq = collections.Counter(subjects)
        annif.util.atomic_save(self._subject_freq,
                               self.datadir,
                               self.FREQ_FILE,
                               method=self._write_freq_file)

        examples = [ex for subj_id, ex in exampledata]
        annif.util.atomic_save(examples,
                               self.datadir,
                               self.TRAIN_FILE,
                               method=self._write_train_file)

    def learn(self, corpus, project):
        self.initialize()
        exampledata = self._create_examples(corpus, project)
        for subj_id, example in exampledata:
            self._model.learn(example)
            self._subject_freq[subj_id] += 1
        modelpath = os.path.join(self.datadir, self.MODEL_FILE)
        self._model.save(modelpath)
        annif.util.atomic_save(self._subject_freq,
                               self.datadir,
                               self.FREQ_FILE,
                               method=self._write_freq_file)
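
# A hypothetical projects.cfg entry using this backend might look like the
# following sketch (the project IDs and vocabulary are illustrative, not
# part of this module):
#
#     [vw-ensemble-en]
#     name=VW ensemble English
#     language=en
#     backend=vw_ensemble
#     sources=tfidf-en,fasttext-en
#     vocab=yso-en
#     discount_rate=0.01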