Passed
Pull Request, main (#926), by Osma
06:42 (queued 03:12)

NNEnsembleModel.forward()   A

Complexity:   Conditions 1
Size:         Total Lines 8, Code Lines 8
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      1
eloc    8
nop     2
dl      0
loc     8
rs      10
c       0
b       0
f       0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from torchmetrics.retrieval import RetrievalNormalizedDCG
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBDataset(Dataset):
    """A sequence of samples stored in an LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter


class NNEnsembleModel(nn.Module):
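    """Feed-forward network used by the NN ensemble backend: the per-source
    score vectors are flattened, passed through dropout and a hidden ReLU
    layer, and the resulting correction ("delta") is added to the mean of the
    source scores."""
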
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
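        # inputs has shape (batch, n_sources, n_subjects); the prediction is
        # the mean of the source score vectors plus a learned correction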
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        return mean + delta

    def save(self, filepath):
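        """Save model weights and configuration to a checkpoint file."""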
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
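        """Load a model from a checkpoint created by save() and put it in eval mode."""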
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
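        # Combine the per-source suggestion batches into a single input tensor
        # (sqrt-scaled scores, weighted per source) and let the trained network
        # produce the merged scores for each document in the batch.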
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    #        summary = []
    #        self._model.summary(print_fn=summary.append)
    #        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
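        """Run every document through all source projects and store the
        resulting (source scores, gold-standard subjects) samples in the LMDB."""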
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
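        # Open the LMDB environment, clearing any previous training data unless
        # cached data from an earlier run is being reused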
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
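        """Vectorize the corpus into the LMDB (unless cached training data is
        reused) and train the network with BCE-with-logits loss, reporting the
        loss and nDCG for each epoch."""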
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(), lr=float(self.params["lr"])
            )
            criterion = nn.BCEWithLogitsLoss()
            ndcg_metric = RetrievalNormalizedDCG(top_k=None)

            self._model.train()
            for epoch in range(epochs):
                ndcg_metric.reset()
                total_loss = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size, n_labels = outputs.shape

                    # Build indexes; each sample is a separate query for nDCG
                    indexes = torch.repeat_interleave(
                        torch.arange(batch_size, device=outputs.device), n_labels
                    )
                    ndcg_metric.update(
                        outputs.reshape(-1), targets.reshape(-1), indexes=indexes
                    )

                    # Update loss stats
                    total_loss += loss.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = ndcg_metric.compute().item()
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
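        """Further train the existing model on additional documents."""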
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )