Passed
Pull Request — main (#926)
by Osma
08:40 queued, 04:24 created

LMDBSequence.__getitem__()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 14
rs 9.75
c 0
b 0
f 0
cc 3
nop 2

"""Neural network-based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from torchmetrics.retrieval import RetrievalNormalizedDCG
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBDataset(Dataset):
    """A dataset of samples stored in an LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into sparse matrices and serialize them as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter


class NNEnsembleModel(nn.Module):
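    """Feed-forward network that refines the combined source scores.

    The forward pass returns the mean of the per-source score vectors plus
    a learned correction ("delta") computed by a single hidden layer from
    the flattened, dropout-regularized inputs.
    """
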
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
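        # inputs: (batch_size, n_sources, n_subjects) stacked source scores;
        # the output is their mean plus a learned per-subject correction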
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        return mean + delta

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
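        # Build a (batch, n_sources, n_subjects) array of sqrt-transformed,
        # weight-scaled source scores and let the model turn it into a
        # single score vector per document.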
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
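                # apply the same sqrt transform and source weighting that
                # _merge_source_batches uses at prediction time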
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
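        # Recreate the database unless cached training data is to be reused;
        # map_size sets the maximum size the LMDB database may grow to.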
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(), lr=float(self.params["lr"]), weight_decay=0
            )
            criterion = nn.BCEWithLogitsLoss()
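            # BCEWithLogitsLoss applies the sigmoid internally, so the model
            # outputs are treated as logits during training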
            ndcg_metric = RetrievalNormalizedDCG(top_k=None)

            self._model.train()
            for epoch in range(epochs):
                ndcg_metric.reset()
                total_loss = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size, n_labels = outputs.shape

                    # Build indexes; each sample is a separate query for nDCG
                    indexes = torch.repeat_interleave(
                        torch.arange(batch_size, device=outputs.device), n_labels
                    )
                    ndcg_metric.update(
                        outputs.reshape(-1), targets.reshape(-1), indexes=indexes
                    )

                    # Update loss stats
                    total_loss += loss.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = ndcg_metric.compute().item()
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )