annif.backend.nn_ensemble    Rating: A

Complexity
  Total Complexity    37

Size/Duplication
  Total Lines         328
  Duplicated Lines    0 %

Importance
  Changes             0

Metric   Value
eloc     234
dl       0
loc      328
rs       9.44
c        0
b        0
f        0
wmc      37

2 Functions

Rating   Name           Duplication   Size   Complexity
A        key_to_idx()   0             3      1
A        idx_to_key()   0             3      1

14 Methods

Rating   Name                                         Duplication   Size   Complexity
A        LMDBSequence.__len__()                       0             3      1
A        LMDBSequence.__getitem__()                   0             14     3
A        MeanLayer.call()                             0             2      1
B        NNEnsembleBackend.initialize()               0             29     5
A        LMDBSequence.add_sample()                    0             10     1
A        LMDBSequence.__init__()                      0             10     2
A        NNEnsembleBackend._create_model()            0             31     2
A        NNEnsembleBackend._merge_source_batches()    0             28     1
B        NNEnsembleBackend._corpus_to_vectors()       0             39     5
A        NNEnsembleBackend.get_model_metadata()       0             12     4
A        NNEnsembleBackend._open_lmdb()               0             5      3
A        NNEnsembleBackend._learn()                   0             8      1
A        NNEnsembleBackend._train()                   0             13     1
B        NNEnsembleBackend._fit_model()               0             25     5
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import importlib
import json
import os.path
import shutil
import zipfile
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import keras
import lmdb
import numpy as np
from keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from keras.models import Model
from keras.saving import load_model
from keras.utils import Sequence
from scipy.sparse import csc_matrix, csr_matrix

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from tensorflow.python.framework.ops import EagerTensor

    from annif.corpus.document import DocumentCorpus

logger = annif.logger

def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)

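# Illustrative example: the zero-padded keys keep LMDB's lexicographic byte
# ordering consistent with numeric sample order, e.g. idx_to_key(42) returns
# b"00000042" and key_to_idx(b"00000042") returns 42.
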
class LMDBSequence(Sequence):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn, batch_size):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0
        self._batch_size = batch_size

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular batch of samples"""
        cursor = self._txn.cursor()
        first_key = idx * self._batch_size
        cursor.set_key(idx_to_key(first_key))
        input_arrays = []
        target_arrays = []
        for key, value in cursor.iternext():
            if key_to_idx(key) >= (first_key + self._batch_size):
                break
            input_csr, target_csr = joblib.load(BytesIO(value))
            input_arrays.append(input_csr.toarray())
            target_arrays.append(target_csr.toarray().flatten())
        return np.array(input_arrays), np.array(target_arrays)

    def __len__(self) -> int:
        """return the number of available batches"""
        return int(np.ceil(self._counter / self._batch_size))

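# Illustrative example: with batch_size=32, __getitem__(0) reads the samples
# stored under keys b"00000000"..b"00000031" and returns one (inputs, targets)
# array pair; __len__() rounds up, so a final partial batch is also served.
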
class MeanLayer(Layer):
    """Custom Keras layer that calculates mean values along the 2nd axis."""

    def call(self, inputs: EagerTensor) -> EagerTensor:
        return keras.ops.mean(inputs, axis=2)

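# Illustrative example: for an input tensor of shape
# (batch, n_subjects, n_sources), MeanLayer averages over the last (source)
# axis and returns shape (batch, n_subjects), i.e. the plain mean of the
# source projects' score vectors.
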
class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.keras"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load TF model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading Keras model from {}".format(model_filename))
        try:
            self._model = load_model(
                model_filename, custom_objects={"MeanLayer": MeanLayer}
            )
        except Exception as err:
            metadata = self.get_model_metadata(model_filename)
            keras_version = importlib.metadata.version("keras")
            message = (
                f"loading Keras model from {model_filename}; "
                f"model metadata: {metadata}; "
                f"you have Keras version {keras_version}. "
                f'Original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = [
            np.array(
                [
                    [
                        np.sqrt(suggestions.as_vector())
                        * src_weight[project_id]
                        * len(batch_by_source)
                        for suggestions in batch
                    ]
                    for project_id, batch in batch_by_source.items()
                ],
                dtype=np.float32,
            ).transpose(1, 2, 0)
        ]
        prediction = self._model(score_vectors).numpy()
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

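    # Illustrative note: the array built above has shape
    # (n_sources, n_docs, n_subjects) before .transpose(1, 2, 0) and
    # (n_docs, n_subjects, n_sources) after it, matching the Input layer
    # defined in _create_model below.
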
    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        inputs = Input(shape=(len(self.project.subjects), len(sources)))

        flat_input = Flatten()(inputs)
        drop_input = Dropout(rate=float(self.params["dropout_rate"]))(flat_input)
        hidden = Dense(int(self.params["nodes"]), activation="relu")(drop_input)
        drop_hidden = Dropout(rate=float(self.params["dropout_rate"]))(hidden)
        delta = Dense(
            len(self.project.subjects),
            kernel_initializer="zeros",
            bias_initializer="zeros",
        )(drop_hidden)

        mean = MeanLayer()(inputs)

        predictions = Add()([mean, delta])

        self._model = Model(inputs=inputs, outputs=predictions)
        self._model.compile(
            optimizer=self.params["optimizer"],
            loss="binary_crossentropy",
            metrics=["top_k_categorical_accuracy"],
        )
        if "lr" in self.params:
            self._model.optimizer.learning_rate.assign(float(self.params["lr"]))

        summary = []
        self._model.summary(print_fn=summary.append)
        self.debug("Created model: \n" + "\n".join(summary))

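    # Illustrative note: because the "delta" Dense layer is zero-initialized
    # and added to the MeanLayer output, a freshly created model starts out
    # predicting the plain mean of the source scores; training then learns
    # per-subject corrections on top of that baseline.
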
    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBSequence,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32).transpose()
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBSequence(txn, batch_size=32)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")
        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            seq = LMDBSequence(txn, batch_size=32)
            self._model.fit(seq, verbose=True, epochs=epochs)

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )

    def get_model_metadata(self, model_filename: str) -> dict | None:
        """Read metadata from Keras model files."""

        try:
            with zipfile.ZipFile(model_filename, "r") as zip:
                with zip.open("metadata.json") as metadata_file:
                    metadata_str = metadata_file.read().decode("utf-8")
                    metadata = json.loads(metadata_str)
                    return metadata
        except Exception:
            self.warning(f"Failed to read metadata from {model_filename}")
            return None
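
As a rough, illustrative sketch of the input preprocessing done in _corpus_to_vectors and _merge_source_batches: each source project's score vector is square-rooted, scaled by the source's configured weight and by the number of sources, and the per-source vectors are stacked into an (n_subjects, n_sources) matrix for the network. The project IDs, weights and scores below are made-up example values, not anything defined by Annif.

import numpy as np

# Hypothetical example values; in Annif these come from the "sources"
# parameter and from the source projects' suggestion results.
sources = {"source-a": 0.5, "source-b": 0.5}  # project_id -> weight
raw_scores = {
    "source-a": np.array([0.0, 0.1, 0.4, 0.0, 0.9, 0.2]),
    "source-b": np.array([0.1, 0.0, 0.5, 0.3, 0.8, 0.0]),
}

# Same transform as in _corpus_to_vectors / _merge_source_batches:
# sqrt of the scores, scaled by the source weight and the number of sources.
doc_scores = [
    np.sqrt(raw_scores[project_id]) * weight * len(sources)
    for project_id, weight in sources.items()
]

# Stack to shape (n_subjects, n_sources); a batch of such matrices is the
# model input, and the training target is the document's binary subject vector.
score_vector = np.array(doc_scores, dtype=np.float32).transpose()
print(score_vector.shape)  # (6, 2)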