Passed
Pull Request — main (#926) by Osma
06:05, queued 03:03

NNEnsembleModel.forward()   A

Complexity
  Conditions: 1

Size
  Total Lines: 8
  Code Lines: 8

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric   Value
cc       1
eloc     8
nop      2
dl       0
loc      8
rs       10
c        0
b        0
f        0

"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from torchmetrics.retrieval import RetrievalNormalizedDCG
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)
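
# Note: keys are zero-padded to 8 digits, e.g. idx_to_key(42) == b"00000042"
# and key_to_idx(b"00000042") == 42, so the lexicographic ordering of LMDB
# keys matches the numeric order in which samples were added.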


class LMDBDataset(Dataset):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter


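# The ensemble model receives, per document, a (n_sources, n_subjects) matrix
# of score vectors. It returns the mean of the source scores plus a learned
# correction ("delta") computed from the flattened scores; since delta_layer
# is zero-initialized, the untrained model starts out predicting the plain
# mean of its sources.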
class NNEnsembleModel(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.hidden.weight)
        nn.init.zeros_(self.hidden.bias)
        nn.init.zeros_(self.delta_layer.weight)
        nn.init.zeros_(self.delta_layer.bias)

    def forward(self, inputs):
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        return mean + delta

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
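        # Apply the same preprocessing as at training time (_corpus_to_vectors):
        # square-root-transform each source's score vector, then scale it by the
        # configured source weight and by the number of sources.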
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
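        # Each stored sample pairs a (n_sources, n_subjects) score matrix with
        # the document's binary gold-standard subject vector.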
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
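        # Unless cached training data was requested, any existing LMDB is
        # removed so that the database is rebuilt from the given corpus.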
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(),
                lr=float(self.params["lr"]),
                weight_decay=0,
                eps=1e-08,
            )
            criterion = nn.BCEWithLogitsLoss()
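            # BCEWithLogitsLoss applies the sigmoid internally, so the model's
            # raw outputs (mean + delta) are treated as per-subject logits here.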
            ndcg_metric = RetrievalNormalizedDCG(top_k=None)

            for epoch in range(epochs):
                self._model.train()
                ndcg_metric.reset()
                total_loss = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size, n_labels = outputs.shape

                    # Build indexes; each sample is a separate query for nDCG
                    indexes = torch.repeat_interleave(
                        torch.arange(batch_size, device=outputs.device), n_labels
                    )
                    ndcg_metric.update(
                        outputs.reshape(-1), targets.reshape(-1), indexes=indexes
                    )

                    # Update loss stats
                    total_loss += loss.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = ndcg_metric.compute().item()
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )
385