Completed: Push to master (21d92b...b5edc6) by Osma
15s (queued 12s)

LMDBSequence.add_sample() (Grade: A)

Complexity
  Conditions: 1

Size
  Total Lines: 10
  Code Lines: 8

Duplication
  Lines: 0
  Ratio: 0%

Importance
  Changes: 0
Metric  Value
eloc        8
dl          0
loc        10
rs         10
c           0
b           0
f           0
cc          1
nop         3
"""Neural network based ensemble backend that combines results from multiple
projects."""


from io import BytesIO
import shutil
import os.path
import numpy as np
from scipy.sparse import csr_matrix, csc_matrix
import joblib
import lmdb
from tensorflow.keras.layers import Input, Dense, Add, Flatten, Lambda, Dropout
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.utils import Sequence
import tensorflow.keras.backend as K
import annif.corpus
import annif.project
import annif.util
from annif.exception import NotInitializedException
from annif.suggestion import VectorSuggestionResult
from . import backend
from . import ensemble


def idx_to_key(idx):
    """convert an integer index to a binary key for use in LMDB"""
    return b'%08d' % idx


def key_to_idx(key):
    """convert a binary LMDB key to an integer index"""
    return int(key)
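
# Example round trip (the zero-padded keys keep LMDB's lexicographic key
# order identical to numeric sample order):
#   idx_to_key(42) == b'00000042'
#   key_to_idx(b'00000042') == 42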


class LMDBSequence(Sequence):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn, batch_size):
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # the last key is the highest index used so far, so the counter
            # (the number of stored samples) is one past it
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0
        self._batch_size = batch_size

    def add_sample(self, inputs, targets):
        """add a new sample to the database"""
        # use a zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into sparse matrices and serialize them as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx):
        """get a particular batch of samples"""
        cursor = self._txn.cursor()
        first_key = idx * self._batch_size
        cursor.set_key(idx_to_key(first_key))
        input_arrays = []
        target_arrays = []
        for key, value in cursor.iternext():
            if key_to_idx(key) >= (first_key + self._batch_size):
                break
            input_csr, target_csr = joblib.load(BytesIO(value))
            input_arrays.append(input_csr.toarray())
            target_arrays.append(target_csr.toarray().flatten())
        return np.array(input_arrays), np.array(target_arrays)

    def __len__(self):
        """return the number of available batches"""
        return int(np.ceil(self._counter / self._batch_size))
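
# A minimal standalone sketch of how LMDBSequence is used (hypothetical path
# and shapes; in Annif the sequence is created by NNEnsembleBackend below):
#
#   env = lmdb.open('/tmp/nn-train.mdb', map_size=1024 * 1024 * 1024,
#                   writemap=True)
#   with env.begin(write=True, buffers=True) as txn:
#       seq = LMDBSequence(txn, batch_size=32)
#       seq.add_sample(np.zeros((5, 2), dtype=np.float32),
#                      np.zeros(5, dtype=np.float32))
#       model.fit(seq, epochs=1)  # any Keras model taking (5, 2) inputs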


class NNEnsembleBackend(
        backend.AnnifLearningBackend,
        ensemble.EnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.h5"
    LMDB_FILE = 'nn-train.mdb'
    LMDB_MAP_SIZE = 1024 * 1024 * 1024

    DEFAULT_PARAMS = {
        'nodes': 100,
        'dropout_rate': 0.2,
        'optimizer': 'adam',
        'epochs': 10,
        'learn-epochs': 1,
    }

    # defaults for uninitialized instances
    _model = None
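
    # Illustrative projects.cfg entry for this backend (the project IDs and
    # weights in `sources` are hypothetical):
    #
    #   [nn-ensemble-en]
    #   name=NN ensemble English
    #   language=en
    #   backend=nn_ensemble
    #   sources=tfidf-en:1,fasttext-en:1
    #   nodes=100
    #   dropout_rate=0.2
    #   epochs=10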

    def default_params(self):
        params = {}
        params.update(super().default_params())
        params.update(self.DEFAULT_PARAMS)
        return params

    def initialize(self):
        if self._model is not None:
            return  # already initialized
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                'model file {} not found'.format(model_filename),
                backend_id=self.backend_id)
        self.debug('loading Keras model from {}'.format(model_filename))
        self._model = load_model(model_filename)
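
    # At suggest time each source project yields a score vector of length
    # len(self.project.subjects); the weighted vectors are stacked and fed
    # to the network as a single batch of shape (1, subjects, sources).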

    def _merge_hits_from_sources(self, hits_from_sources, params):
        score_vector = np.array([hits.vector * weight
                                 for hits, weight in hits_from_sources],
                                dtype=np.float32)
        results = self._model.predict(
            np.expand_dims(score_vector.transpose(), 0))
        return VectorSuggestionResult(results[0], self.project.subjects)

    def _create_model(self, sources):
        self.info("creating NN ensemble model")

        inputs = Input(shape=(len(self.project.subjects), len(sources)))

        flat_input = Flatten()(inputs)
        drop_input = Dropout(
            rate=float(self.params['dropout_rate']))(flat_input)
        hidden = Dense(int(self.params['nodes']),
                       activation="relu")(drop_input)
        drop_hidden = Dropout(rate=float(self.params['dropout_rate']))(hidden)
        delta = Dense(len(self.project.subjects),
                      kernel_initializer='zeros',
                      bias_initializer='zeros')(drop_hidden)
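
        # the Dense layer above produces a correction term (delta) to the
        # plain mean of the source scores computed below; with its weights
        # initialized to zero the model initially predicts exactly the mean,
        # and training only learns how to adjust it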

        mean = Lambda(lambda x: K.mean(x, axis=2))(inputs)

        predictions = Add()([mean, delta])

        self._model = Model(inputs=inputs, outputs=predictions)
        self._model.compile(optimizer=self.params['optimizer'],
                            loss='binary_crossentropy',
                            metrics=['top_k_categorical_accuracy'])

        summary = []
        self._model.summary(print_fn=summary.append)
        self.debug("Created model: \n" + "\n".join(summary))

    def _train(self, corpus, params):
        sources = annif.util.parse_sources(self.params['sources'])
        self._create_model(sources)
        self._fit_model(corpus, epochs=int(params['epochs']))

    def _corpus_to_vectors(self, corpus, seq):
        # pass corpus through all source projects
        sources = [(annif.project.get_project(project_id), weight)
                   for project_id, weight
                   in annif.util.parse_sources(self.params['sources'])]

        for doc in corpus.documents:
            doc_scores = []
            for source_project, weight in sources:
                hits = source_project.suggest(doc.text)
                doc_scores.append(hits.vector * weight)
            score_vector = np.array(doc_scores,
                                    dtype=np.float32).transpose()
            subjects = annif.corpus.SubjectSet((doc.uris, doc.labels))
            true_vector = subjects.as_vector(self.project.subjects)
            seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=self.LMDB_MAP_SIZE, writemap=True)

    def _fit_model(self, corpus, epochs):
        env = self._open_lmdb(corpus == 'cached')
        with env.begin(write=True, buffers=True) as txn:
            seq = LMDBSequence(txn, batch_size=32)
            if corpus != 'cached':
                self._corpus_to_vectors(corpus, seq)
            else:
                self.info("Reusing cached training data from previous run.")

            # fit the model
            self._model.fit(seq, verbose=True, epochs=epochs)
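
        # atomic_save writes the model to a temporary file first, so an
        # interrupted run cannot leave a partially written model behind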

        annif.util.atomic_save(
            self._model,
            self.datadir,
            self.MODEL_FILE)

    def _learn(self, corpus, params):
        self.initialize()
        self._fit_model(corpus, int(params['learn-epochs']))
203