Passed
Push — issue631-rest-api-language-det... ( 34c253...1cd800 )
by Osma
04:27
created

NNEnsembleBackend._fit_model()   B

Complexity
    Conditions: 5

Size
    Total Lines: 25
    Code Lines: 20

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
cc        5
eloc      20
nop       5
dl        0
loc       25
rs        8.9332
c         0
b         0
f         0
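
The cc and line-count figures above can be approximated locally for a quick sanity check. The sketch below uses the radon library, which is an assumption: it is not necessarily the tool that produced this report, and the source path is illustrative. The full source of the module under analysis follows after the sketch.

from radon.complexity import cc_visit
from radon.raw import analyze

# Illustrative path to the module under analysis; adjust to your checkout
SOURCE_PATH = "annif/backend/nn_ensemble.py"

with open(SOURCE_PATH) as src_file:
    source = src_file.read()

# Cyclomatic complexity per function or method (compare with the cc value above)
for block in cc_visit(source):
    # cc_visit returns top-level functions and classes; methods hang off the classes
    for unit in getattr(block, "methods", None) or [block]:
        print(f"{unit.name}: cyclomatic complexity {unit.complexity}")

# Raw size metrics for the whole module (compare with the line counts above)
raw = analyze(source)
print(f"module: loc={raw.loc} sloc={raw.sloc} blank={raw.blank}")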
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import importlib
import json
import os.path
import shutil
import zipfile
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import keras.backend as K
import lmdb
import numpy as np
from keras.layers import Add, Dense, Dropout, Flatten, Input, Layer
from keras.models import Model
from keras.saving import load_model
from keras.utils import Sequence
from scipy.sparse import csc_matrix, csr_matrix

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from tensorflow.python.framework.ops import EagerTensor

    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBSequence(Sequence):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn, batch_size):
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0
        self._batch_size = batch_size

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular batch of samples"""
        cursor = self._txn.cursor()
        first_key = idx * self._batch_size
        cursor.set_key(idx_to_key(first_key))
        input_arrays = []
        target_arrays = []
        for key, value in cursor.iternext():
            if key_to_idx(key) >= (first_key + self._batch_size):
                break
            input_csr, target_csr = joblib.load(BytesIO(value))
            input_arrays.append(input_csr.toarray())
            target_arrays.append(target_csr.toarray().flatten())
        return np.array(input_arrays), np.array(target_arrays)

    def __len__(self) -> int:
        """return the number of available batches"""
        return int(np.ceil(self._counter / self._batch_size))


class MeanLayer(Layer):
    """Custom Keras layer that calculates mean values along the 2nd axis."""

    def call(self, inputs: EagerTensor) -> EagerTensor:
        return K.mean(inputs, axis=2)


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.keras"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load TF model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading Keras model from {}".format(model_filename))
        try:
            self._model = load_model(
                model_filename, custom_objects={"MeanLayer": MeanLayer}
            )
        except Exception as err:
            metadata = self.get_model_metadata(model_filename)
            keras_version = importlib.metadata.version("keras")
            message = (
                f"loading Keras model from {model_filename}; "
                f"model metadata: {metadata}; "
                f"you have Keras version {keras_version}. "
                f'Original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        ).transpose(1, 2, 0)
        prediction = self._model(score_vectors).numpy()
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        inputs = Input(shape=(len(self.project.subjects), len(sources)))

        flat_input = Flatten()(inputs)
        drop_input = Dropout(rate=float(self.params["dropout_rate"]))(flat_input)
        hidden = Dense(int(self.params["nodes"]), activation="relu")(drop_input)
        drop_hidden = Dropout(rate=float(self.params["dropout_rate"]))(hidden)
        delta = Dense(
            len(self.project.subjects),
            kernel_initializer="zeros",
            bias_initializer="zeros",
        )(drop_hidden)

        mean = MeanLayer()(inputs)

        predictions = Add()([mean, delta])

        self._model = Model(inputs=inputs, outputs=predictions)
        self._model.compile(
            optimizer=self.params["optimizer"],
            loss="binary_crossentropy",
            metrics=["top_k_categorical_accuracy"],
        )
        if "lr" in self.params:
            self._model.optimizer.learning_rate.assign(float(self.params["lr"]))

        summary = []
        self._model.summary(print_fn=summary.append)
        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBSequence,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32).transpose()
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBSequence(txn, batch_size=32)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")
        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            seq = LMDBSequence(txn, batch_size=32)
            self._model.fit(seq, verbose=True, epochs=epochs)

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )

    def get_model_metadata(self, model_filename: str) -> dict | None:
        """Read metadata from Keras model files."""

        try:
            with zipfile.ZipFile(model_filename, "r") as zip:
                with zip.open("metadata.json") as metadata_file:
                    metadata_str = metadata_file.read().decode("utf-8")
                    metadata = json.loads(metadata_str)
                    return metadata
        except Exception:
            self.warning(f"Failed to read metadata from {model_filename}")
            return None
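
As a usage sketch, this backend is wired into an Annif instance through a project configuration entry. The snippet below is hypothetical: the project ID, source project IDs, vocabulary, and weights are made up for illustration, while the parameter names (sources, nodes, dropout_rate, epochs, limit) correspond to the ones read by the code above.

[nn-ensemble-en]
name=NN ensemble English
language=en
backend=nn_ensemble
vocab=yso
sources=tfidf-en:1,fasttext-en:2
nodes=100
dropout_rate=0.2
epochs=10
limit=100

With such an entry in place, training is normally started with the annif train command (for example: annif train nn-ensemble-en /path/to/corpus --jobs 4), which ends up in _train() and _fit_model() above, while annif learn drives the _learn() path for incremental updates.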
325