Passed: Pull Request on main (#926) by Osma
Duration 07:29 (queued 04:20)

NNEnsembleModel.forward() (rated A)

Complexity:  Conditions 1
Size:        Total Lines 9, Code Lines 9
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      1
eloc    9
nop     2
dl      0
loc     9
rs      9.95
c       0
b       0
f       0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)
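
# Illustrative note (not part of the module): these helpers round-trip an index
# through a zero-padded ASCII form whose lexicographic order matches numeric
# order, e.g. idx_to_key(42) == b"00000042" and key_to_idx(b"00000042") == 42.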


class LMDBDataset(Dataset):
    """A sequence of samples stored in an LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into sparse matrices and serialize them as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter
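
# Usage sketch under assumed names (env is an open lmdb.Environment and
# score_vector / true_vector are dense numpy arrays), mirroring how
# _fit_model and _corpus_to_vectors below use the class:
#
#   with env.begin(write=True, buffers=True) as txn:
#       LMDBDataset(txn).add_sample(score_vector, true_vector)
#   with env.begin(buffers=True) as txn:
#       inputs, targets = LMDBDataset(txn)[0]  # returned as dense torch tensors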


class NNEnsembleModel(nn.Module):
    """Feed-forward network that predicts subject scores as the mean of the
    source scores plus a learned correction (delta)."""

    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # zero-initialize the delta layer so an untrained model predicts the
        # plain mean of the source scores
        nn.init.zeros_(self.delta_layer.weight)
        nn.init.zeros_(self.delta_layer.bias)

    def forward(self, inputs):
        # inputs: (batch, n_sources, n_subjects)
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        corrected = mean + delta
        return torch.clamp(corrected, min=0.0, max=1.0)

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model
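
# Round-trip sketch with assumed toy dimensions (two sources, a vocabulary of
# 1000 subjects):
#
#   model = NNEnsembleModel(input_dim=2000, hidden_dim=100,
#                           output_dim=1000, dropout_rate=0.2)
#   model.save("nn-model.pt")
#   restored = NNEnsembleModel.load("nn-model.pt")
#
# load() rebuilds the architecture from the stored model_config, restores the
# weights and leaves the module in eval mode.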


@torch.no_grad()
def ndcg_batch(preds: torch.Tensor, targets: torch.Tensor):
    """
    preds:   (B, N) float scores
    targets: (B, N) binary relevance labels in {0, 1}

    Returns: mean nDCG across the batch
    """
    B, N = preds.shape

    # order the relevance labels by descending predicted score
    sorted_idx = torch.argsort(preds, dim=1, descending=True)
    sorted_targets = torch.gather(targets, 1, sorted_idx)

    L = sorted_targets.size(1)
    ranks = torch.arange(1, L + 1, device=preds.device)
    discounts = 1.0 / torch.log2(ranks + 1)

    dcg = (sorted_targets * discounts).sum(dim=1)

    # ideal DCG: the same discounts applied to a perfect ranking of the targets
    ideal_sorted = torch.sort(targets, dim=1, descending=True).values[:, :L]
    idcg = (ideal_sorted * discounts).sum(dim=1)

    ndcg = dcg / torch.clamp(idcg, min=1e-8)

    return ndcg.mean()
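
# Worked example (assumed toy values): for preds = [[0.9, 0.1, 0.8]] and
# targets = [[1, 0, 1]] the two relevant labels are ranked 1st and 2nd, so
# DCG = 1/log2(2) + 1/log2(3) ≈ 1.63, which equals the ideal DCG and yields
# nDCG = 1.0; ranking an irrelevant label first would lower the score.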


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }
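
    # Configuration sketch (hypothetical projects.cfg entry; project ids,
    # vocabulary and source weights are placeholders, other keys fall back to
    # DEFAULT_PARAMETERS above):
    #
    #   [nn-ensemble-en]
    #   name=NN ensemble English
    #   language=en
    #   backend=nn_ensemble
    #   vocab=yso
    #   sources=tfidf-en:1,omikuji-parabel-en:2
    #   nodes=100
    #   epochs=10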

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        # shape (n_sources, batch, n_subjects); each source's score vector is
        # square-rooted and scaled by its weight and the number of sources
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        # swap to (batch, n_sources, n_subjects) as expected by the model
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )
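
    # Training is normally driven from the Annif CLI, e.g.
    # `annif train nn-ensemble-en /path/to/corpus` (project id and path are
    # placeholders); `annif train --cached nn-ensemble-en` retrains from the
    # LMDB data prepared on a previous run, as handled in _fit_model below.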

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

            # Training loop
            optimizer = torch.optim.AdamW(
                self._model.parameters(),
                lr=float(self.params["lr"]),
                weight_decay=0.01,
                eps=1e-08,
            )
            criterion = nn.BCELoss()

            for epoch in range(epochs):
                self._model.train()
                total_loss = 0.0
                total_ndcg = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size, n_labels = outputs.shape

                    batch_ndcg = ndcg_batch(outputs, targets)

                    # Update stats
                    total_loss += loss.item() * batch_size
                    total_ndcg += batch_ndcg.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = total_ndcg / total_samples
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

            annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)
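
    # The trained network ends up as MODEL_FILE ("nn-model.pt") in the backend
    # data directory; initialize() above loads it back with
    # NNEnsembleModel.load() when the project is next used.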

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )