Passed
Push — testing-on-windows-and-macos (782857...ea99ad) by Juho, created 04:06

NNEnsembleBackend.get_model_metadata()   A

Complexity
  Conditions: 4

Size
  Total Lines: 12
  Code Lines: 10

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      4
eloc    10
nop     2
dl      0
loc     12
rs      9.9
c       0
b       0
f       0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import importlib.metadata
import json
import os.path
import shutil
import zipfile
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import keras
import lmdb
import numpy as np
from keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from keras.models import Model
from keras.saving import load_model
from keras.utils import Sequence
from scipy.sparse import csc_matrix, csr_matrix

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from tensorflow.python.framework.ops import EagerTensor

    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBSequence(Sequence):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn, batch_size):
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0
        self._batch_size = batch_size

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular batch of samples"""
        cursor = self._txn.cursor()
        first_key = idx * self._batch_size
        cursor.set_key(idx_to_key(first_key))
        input_arrays = []
        target_arrays = []
        for key, value in cursor.iternext():
            if key_to_idx(key) >= (first_key + self._batch_size):
                break
            input_csr, target_csr = joblib.load(BytesIO(value))
            input_arrays.append(input_csr.toarray())
            target_arrays.append(target_csr.toarray().flatten())
        return np.array(input_arrays), np.array(target_arrays)

    def __len__(self) -> int:
        """return the number of available batches"""
        return int(np.ceil(self._counter / self._batch_size))


class MeanLayer(Layer):
    """Custom Keras layer that calculates mean values along the 2nd axis."""

    def call(self, inputs: EagerTensor) -> EagerTensor:
        return keras.ops.mean(inputs, axis=2)


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.keras"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load TF model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading Keras model from {}".format(model_filename))
        try:
            self._model = load_model(
                model_filename, custom_objects={"MeanLayer": MeanLayer}
            )
        except Exception as err:
            metadata = self.get_model_metadata(model_filename)
            keras_version = importlib.metadata.version("keras")
            message = (
                f"loading Keras model from {model_filename}; "
                f"model metadata: {metadata}; "
                f"you have Keras version {keras_version}. "
                f'Original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        ).transpose(1, 2, 0)
        prediction = self._model(score_vectors).numpy()
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        inputs = Input(shape=(len(self.project.subjects), len(sources)))

        flat_input = Flatten()(inputs)
        drop_input = Dropout(rate=float(self.params["dropout_rate"]))(flat_input)
        hidden = Dense(int(self.params["nodes"]), activation="relu")(drop_input)
        drop_hidden = Dropout(rate=float(self.params["dropout_rate"]))(hidden)
        delta = Dense(
            len(self.project.subjects),
            kernel_initializer="zeros",
            bias_initializer="zeros",
        )(drop_hidden)

        mean = MeanLayer()(inputs)

        predictions = Add()([mean, delta])

        self._model = Model(inputs=inputs, outputs=predictions)
        self._model.compile(
            optimizer=self.params["optimizer"],
            loss="binary_crossentropy",
            metrics=["top_k_categorical_accuracy"],
        )
        if "lr" in self.params:
            self._model.optimizer.learning_rate.assign(float(self.params["lr"]))

        summary = []
        self._model.summary(print_fn=summary.append)
        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBSequence,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32).transpose()
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBSequence(txn, batch_size=32)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")
        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            seq = LMDBSequence(txn, batch_size=32)
            self._model.fit(seq, verbose=True, epochs=epochs)

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )

    def get_model_metadata(self, model_filename: str) -> dict | None:
        """Read metadata from Keras model files."""

        try:
            with zipfile.ZipFile(model_filename, "r") as zip:
                with zip.open("metadata.json") as metadata_file:
                    metadata_str = metadata_file.read().decode("utf-8")
                    metadata = json.loads(metadata_str)
                    return metadata
        except Exception:
            self.warning(f"Failed to read metadata from {model_filename}")
            return None
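
For context on what get_model_metadata() reads: a Keras v3 ".keras" file is a zip archive whose metadata.json records details such as the Keras version used when the model was saved. Below is a minimal standalone sketch of the same logic outside the backend class; the helper name read_keras_metadata and the example path are illustrative only and not part of the Annif code above.

import json
import zipfile


def read_keras_metadata(model_path: str) -> dict | None:
    # A .keras file is a zip archive; metadata.json inside it holds
    # saving-time details such as the Keras version.
    try:
        with zipfile.ZipFile(model_path, "r") as zf:
            with zf.open("metadata.json") as metadata_file:
                return json.loads(metadata_file.read().decode("utf-8"))
    except (OSError, zipfile.BadZipFile, KeyError, json.JSONDecodeError):
        # missing file, not a zip archive, or no/invalid metadata.json
        return None


# Hypothetical usage against a model saved by this backend:
# read_keras_metadata("data/projects/my-ensemble/nn-model.keras")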
325