Passed
Pull Request — main (#926) by Osma
07:22 queued 03:55 created

annif.backend.nn_ensemble.LMDBDataset.add_sample()   A

Complexity
  Conditions: 1

Size
  Total Lines: 10
  Code Lines: 8

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric  Value
cc      1
eloc    8
nop     3
dl      0
loc     10
rs      10
c       0
b       0
f       0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger

def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)

class LMDBDataset(Dataset):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        """add a new sample to the end of the database"""
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter

class NNEnsembleModel(nn.Module):
    """Ensemble model that combines source project score vectors: a 1x1
    convolution forms a learned weighted mean over the sources, and a
    hidden layer predicts a correction that is added to the mean before
    clamping to [0, 1]."""

    def __init__(
        self,
        source_dim: int,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        dropout_rate: float,
    ):
        super().__init__()
        self.model_config = {
            "source_dim": source_dim,
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.conv = nn.Conv1d(source_dim, 1, 1, bias=False)
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim * source_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # start from a plain mean of the sources with a zero correction
        self.conv.weight.data.fill_(1 / self.model_config["source_dim"])
        nn.init.zeros_(self.delta_layer.weight)
        nn.init.zeros_(self.delta_layer.bias)

    def forward(self, inputs):
        # weighted mean of the source score vectors
        mean = self.conv(inputs)[:, 0, :]
        # hidden layer predicts a correction (delta) to the mean
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        corrected = mean + delta
        return torch.clamp(corrected, min=0.0, max=1.0)

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model

@torch.no_grad()
def ndcg_batch(preds: torch.Tensor, targets: torch.Tensor):
    """
    preds:   (B, N) float
    targets: (B, N) {0,1}

    Returns: mean nDCG across the batch
    """
    sorted_idx = torch.argsort(preds, dim=1, descending=True)
    sorted_targets = torch.gather(targets, 1, sorted_idx)

    L = sorted_targets.size(1)
    ranks = torch.arange(1, L + 1, device=preds.device)
    discounts = 1.0 / torch.log2(ranks + 1)

    dcg = (sorted_targets * discounts).sum(dim=1)

    ideal_sorted = torch.sort(targets, dim=1, descending=True).values[:, :L]
    idcg = (ideal_sorted * discounts).sum(dim=1)

    ndcg = dcg / torch.clamp(idcg, min=1e-8)

    return ndcg.mean()
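
# For reference: with rel_i denoting the binary target at rank i of the
# prediction-sorted list, ndcg_batch computes DCG = sum_i rel_i / log2(i + 1),
# IDCG as the same sum over the ideally sorted targets, and returns the
# batch mean of nDCG = DCG / max(IDCG, 1e-8).
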
class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        # reorder from (sources, documents, subjects) to (documents, sources, subjects)
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        source_dim = len(sources)
        input_dim = len(self.project.subjects)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            source_dim=source_dim,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

            # Training loop
            optimizer = torch.optim.AdamW(
                self._model.parameters(),
                lr=float(self.params["lr"]),
                weight_decay=0.01,
                eps=1e-08,
            )
            criterion = nn.BCELoss()

            for epoch in range(epochs):
                self._model.train()
                total_loss = 0.0
                total_ndcg = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size = outputs.shape[0]

                    batch_ndcg = ndcg_batch(outputs, targets)

                    # Update stats
                    total_loss += loss.item() * batch_size
                    total_ndcg += batch_ndcg.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = total_ndcg / total_samples
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )
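
For context, a minimal sketch of how the graded LMDBDataset.add_sample() method is exercised, assuming annif and its lmdb, joblib, numpy and torch dependencies are installed; the temporary directory and toy dimensions are illustrative only, not taken from the PR:

import tempfile

import lmdb
import numpy as np

from annif.backend.nn_ensemble import LMDBDataset

with tempfile.TemporaryDirectory() as tmpdir:
    # writable LMDB environment backed by a throwaway directory
    env = lmdb.open(tmpdir, map_size=1024 * 1024, writemap=True)
    with env.begin(write=True, buffers=True) as txn:
        dataset = LMDBDataset(txn)
        # one sample: score vectors from 2 source projects over 5 subjects,
        # plus the binary target vector for the same 5 subjects
        inputs = np.random.rand(2, 5).astype(np.float32)
        targets = np.array([0, 1, 0, 0, 1], dtype=np.float32)
        dataset.add_sample(inputs, targets)
        assert len(dataset) == 1
        input_tensor, target_tensor = dataset[0]
        print(input_tensor.shape, target_tensor.shape)  # torch.Size([2, 5]) torch.Size([5])
    env.close()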