NNEnsembleBackend._train()    Grade: A
last analyzed

Complexity
Conditions: 1

Size
Total Lines: 13
Code Lines: 12

Duplication
Lines: 0
Ratio: 0 %

Importance
Changes: 0

Metric  Value
cc      1
eloc    12
nop     4
dl      0
loc     13
rs      9.8
c       0
b       0
f       0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBDataset(Dataset):
    """A sequence of samples stored in an LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into sparse matrices and serialize them as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter


class NNEnsembleModel(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        return mean + delta

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id) from err

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    #        summary = []
    #        self._model.summary(print_fn=summary.append)
    #        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(), lr=float(self.params["lr"])
            )
            criterion = nn.BCEWithLogitsLoss()

            self._model.train()
            for epoch in range(epochs):
                for inputs, targets in dataloader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )
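
For reference, the combination implemented by NNEnsembleModel.forward() and used in _merge_source_batches() can be exercised in isolation: the model receives one (sources x subjects) matrix of sqrt-scaled, weight-multiplied score vectors per document and returns the per-subject mean plus a learned correction. The sketch below is illustrative only and is not part of the module above; the import path, the source-project names and the vocabulary size are assumptions.

# Illustrative sketch, not part of the module listed above.
# Assumed import path; dimensions and project names are made up.
import numpy as np
import torch

from annif.backend.nn_ensemble import NNEnsembleModel

n_subjects = 5                              # hypothetical vocabulary size
sources = {"tfidf": 0.5, "fasttext": 0.5}   # hypothetical source projects

# One document: a (n_sources, n_subjects) matrix of sqrt-scaled,
# weight-multiplied scores, as built in _merge_source_batches()
raw = np.random.rand(len(sources), n_subjects).astype(np.float32)
weights = np.array(list(sources.values()), dtype=np.float32)[:, None]
score_vectors = np.sqrt(raw) * weights * len(sources)

# add the batch dimension expected by forward()
batch = torch.from_numpy(score_vectors[np.newaxis, ...])

model = NNEnsembleModel(
    input_dim=len(sources) * n_subjects,
    hidden_dim=100,
    output_dim=n_subjects,
    dropout_rate=0.2,
)
model.eval()  # disable dropout, as after NNEnsembleModel.load()
with torch.no_grad():
    prediction = model(batch)

# forward() returns the per-subject mean of the source scores plus the
# learned delta, so there is one score per subject for each document
assert prediction.shape == (1, n_subjects)

During training, _corpus_to_vectors() builds the same kind of score matrices and stores them in the LMDBDataset, and _fit_model() fits the delta against the gold-standard subject vectors with BCEWithLogitsLoss.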
344