annif.backend.nn_ensemble   (rating: A)

Complexity

Total Complexity 37

Size/Duplication

Total Lines 326
Duplicated Lines 0 %

Importance

Changes 0
Metric   Value
eloc     233
dl       0
loc      326
rs       9.44
c        0
b        0
f        0
wmc      37

2 Functions

Rating   Name   Duplication   Size   Complexity  
A key_to_idx() 0 3 1
A idx_to_key() 0 3 1

14 Methods

Rating   Name   Duplication   Size   Complexity  
A LMDBSequence.__len__() 0 3 1
A NNEnsembleBackend._create_model() 0 31 2
A NNEnsembleBackend._merge_source_batches() 0 26 1
A LMDBSequence.__getitem__() 0 14 3
A MeanLayer.call() 0 2 1
B NNEnsembleBackend._corpus_to_vectors() 0 39 5
B NNEnsembleBackend.initialize() 0 29 5
A NNEnsembleBackend.get_model_metadata() 0 12 4
A LMDBSequence.add_sample() 0 10 1
A NNEnsembleBackend._open_lmdb() 0 5 3
A NNEnsembleBackend._learn() 0 8 1
A NNEnsembleBackend._train() 0 13 1
A LMDBSequence.__init__() 0 10 2
B NNEnsembleBackend._fit_model() 0 25 5
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import importlib
import json
import os.path
import shutil
import zipfile
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import keras
import lmdb
import numpy as np
from keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from keras.models import Model
from keras.saving import load_model
from keras.utils import Sequence
from scipy.sparse import csc_matrix, csr_matrix

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from tensorflow.python.framework.ops import EagerTensor

    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)

class LMDBSequence(Sequence):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn, batch_size):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0
        self._batch_size = batch_size

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular batch of samples"""
        cursor = self._txn.cursor()
        first_key = idx * self._batch_size
        cursor.set_key(idx_to_key(first_key))
        input_arrays = []
        target_arrays = []
        for key, value in cursor.iternext():
            if key_to_idx(key) >= (first_key + self._batch_size):
                break
            input_csr, target_csr = joblib.load(BytesIO(value))
            input_arrays.append(input_csr.toarray())
            target_arrays.append(target_csr.toarray().flatten())
        return np.array(input_arrays), np.array(target_arrays)

    def __len__(self) -> int:
        """return the number of available batches"""
        return int(np.ceil(self._counter / self._batch_size))

class MeanLayer(Layer):
    """Custom Keras layer that calculates mean values along the 2nd axis."""

    def call(self, inputs: EagerTensor) -> EagerTensor:
        return keras.ops.mean(inputs, axis=2)


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.keras"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load TF model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading Keras model from {}".format(model_filename))
        try:
            self._model = load_model(
                model_filename, custom_objects={"MeanLayer": MeanLayer}
            )
        except Exception as err:
            metadata = self.get_model_metadata(model_filename)
            keras_version = importlib.metadata.version("keras")
            message = (
                f"loading Keras model from {model_filename}; "
                f"model metadata: {metadata}; "
                f"you have Keras version {keras_version}. "
                f'Original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
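        # stack the source batches into a tensor of shape
        # (documents, subjects, sources): scores are sqrt-scaled, weighted
        # by the source weight and scaled by the number of sources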
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        ).transpose(1, 2, 0)
        prediction = self._model(score_vectors).numpy()
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

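        # the network takes one score vector per source project for each
        # subject and learns a correction (delta) that is added to the
        # plain mean of the source scores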
        inputs = Input(shape=(len(self.project.subjects), len(sources)))

        flat_input = Flatten()(inputs)
        drop_input = Dropout(rate=float(self.params["dropout_rate"]))(flat_input)
        hidden = Dense(int(self.params["nodes"]), activation="relu")(drop_input)
        drop_hidden = Dropout(rate=float(self.params["dropout_rate"]))(hidden)
        delta = Dense(
            len(self.project.subjects),
            kernel_initializer="zeros",
            bias_initializer="zeros",
        )(drop_hidden)

        mean = MeanLayer()(inputs)

        predictions = Add()([mean, delta])

        self._model = Model(inputs=inputs, outputs=predictions)
        self._model.compile(
            optimizer=self.params["optimizer"],
            loss="binary_crossentropy",
            metrics=["top_k_categorical_accuracy"],
        )
        if "lr" in self.params:
            self._model.optimizer.learning_rate.assign(float(self.params["lr"]))

        summary = []
        self._model.summary(print_fn=summary.append)
        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
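        # build a fresh model and fit it on suggestion vectors from the
        # source projects, which are cached in an LMDB database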
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBSequence,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
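                # score_vector has shape (subjects, sources), matching the
                # model's input layer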
                score_vector = np.array(doc_scores, dtype=np.float32).transpose()
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
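        # start from an empty database unless cached training data is reused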
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBSequence(txn, batch_size=32)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")
        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            seq = LMDBSequence(txn, batch_size=32)
            self._model.fit(seq, verbose=True, epochs=epochs)

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
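        # load the existing model and continue training it on new documents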
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )

    def get_model_metadata(self, model_filename: str) -> dict | None:
        """Read metadata from Keras model files."""

        try:
            with zipfile.ZipFile(model_filename, "r") as zip:
                with zip.open("metadata.json") as metadata_file:
                    metadata_str = metadata_file.read().decode("utf-8")
                    metadata = json.loads(metadata_str)
                    return metadata
        except Exception:
            self.warning(f"Failed to read metadata from {model_filename}")
            return None
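
As a rough illustration of how the LMDBSequence class above is consumed by _fit_model, the following standalone sketch (not part of the module; the database path, map size and sample dimensions are invented for the example) writes one training sample into a scratch LMDB database and reads it back as a Keras-compatible batch:

import lmdb
import numpy as np

from annif.backend.nn_ensemble import LMDBSequence

# scratch LMDB environment; path and map_size are arbitrary for this example
env = lmdb.open("/tmp/nn-train-example.mdb", map_size=64 * 1024 * 1024, writemap=True)

# write one sample inside a write transaction, as _fit_model does during training
with env.begin(write=True, buffers=True) as txn:
    seq = LMDBSequence(txn, batch_size=32)
    inputs = np.random.rand(5, 2).astype(np.float32)  # scores: 5 subjects x 2 sources
    targets = np.array([1, 0, 0, 1, 0], dtype=np.float32)  # true subject assignments
    seq.add_sample(inputs, targets)

# read the data back in batches through a read-only transaction
with env.begin(buffers=True) as txn:
    seq = LMDBSequence(txn, batch_size=32)
    batch_inputs, batch_targets = seq[0]
    print(batch_inputs.shape, batch_targets.shape)  # (1, 5, 2) (1, 5)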