Passed: Pull Request — main (#926), by Osma (05:59, queued 03:01)

NNEnsembleModel.reset_parameters()   A

Complexity: Conditions 1
Size: Total Lines 5, Code Lines 5
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric | Value
cc     | 1
eloc   | 5
nop    | 1
dl     | 0
loc    | 5
rs     | 10
c      | 0
b      | 0
f      | 0
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from torchmetrics.retrieval import RetrievalNormalizedDCG
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger

def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)
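
# Example (illustrative): the two helpers are inverses, e.g.
# idx_to_key(42) == b"00000042" and key_to_idx(b"00000042") == 42.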


class LMDBDataset(Dataset):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter
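
# Note (illustrative): each sample is stored as a joblib-serialized pair of
# sparse matrices, and __getitem__ densifies them again, so with the
# (n_sources, n_subjects) score arrays added by _corpus_to_vectors below, a
# default DataLoader over this dataset yields float inputs of shape
# (batch, n_sources, n_subjects) and float targets of shape (batch, n_subjects).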


class NNEnsembleModel(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.hidden.weight)
        nn.init.zeros_(self.hidden.bias)
        nn.init.zeros_(self.delta_layer.weight)
        nn.init.zeros_(self.delta_layer.bias)

    def forward(self, inputs):
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
        corrected = mean + delta
        return torch.clamp(corrected, min=0.0, max=1.0)

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model
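
# Note (illustrative): forward() expects score tensors of shape
# (batch, n_sources, n_subjects), so input_dim must equal n_sources * n_subjects
# (see _create_model below); the output is the per-subject mean over the sources
# plus a learned correction, clamped to the [0, 1] score range.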


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )
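
    # Note (illustrative): the sqrt + source-weight * n_sources scaling applied
    # here at suggestion time mirrors the preprocessing in _corpus_to_vectors
    # below, so the model sees inputs on the same scale during training and
    # inference.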

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(),
                lr=float(self.params["lr"]),
                weight_decay=0,
                eps=1e-08,
            )
            criterion = nn.BCELoss()
            ndcg_metric = RetrievalNormalizedDCG(top_k=None)

            for epoch in range(epochs):
                self._model.train()
                ndcg_metric.reset()
                total_loss = 0.0
                total_samples = 0
                tqdm_loader = tqdm(
                    dataloader,
                    desc=f"Epoch {epoch + 1}/{epochs}",
                    postfix={"loss": "0.000"},
                )
                for inputs, targets in tqdm_loader:
                    optimizer.zero_grad()
                    outputs = self._model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    optimizer.step()

                    batch_size, n_labels = outputs.shape

                    # Build indexes; each sample is a separate query for nDCG
                    indexes = torch.repeat_interleave(
                        torch.arange(batch_size, device=outputs.device), n_labels
                    )
                    ndcg_metric.update(
                        outputs.reshape(-1), targets.reshape(-1), indexes=indexes
                    )

                    # Update loss stats
                    total_loss += loss.item() * batch_size
                    total_samples += batch_size

                    # Update progress bar with batch loss
                    tqdm_loader.set_postfix(loss=loss.item())

                epoch_loss = total_loss / total_samples
                epoch_ndcg = ndcg_metric.compute().item()
                print(
                    f"Epoch {epoch + 1}/{epochs} "
                    f"- loss: {epoch_loss:.4f} "
                    f"- nDCG: {epoch_ndcg:.4f}"
                )

        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _learn(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
    ) -> None:
        self.initialize()
        self._fit_model(
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
        )
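
For reference, here is a minimal sketch of exercising NNEnsembleModel on its own; the dimensions (4 sources, 1000 subjects), the batch of random scores, and the /tmp path are illustrative assumptions rather than values taken from the code above.

import torch

n_sources, n_subjects = 4, 1000  # assumed example dimensions
model = NNEnsembleModel(
    input_dim=n_sources * n_subjects,
    hidden_dim=100,
    output_dim=n_subjects,
    dropout_rate=0.2,
)
model.eval()  # disable dropout for inference

# one batch of per-source score vectors, shaped (batch, n_sources, n_subjects)
scores = torch.rand(8, n_sources, n_subjects)
with torch.no_grad():
    preds = model(scores)  # shape (8, n_subjects), values clamped to [0, 1]

# checkpoint round trip via the save/load helpers defined above
model.save("/tmp/nn-model.pt")
restored = NNEnsembleModel.load("/tmp/nn-model.pt")

The hidden_dim of 100 and dropout_rate of 0.2 correspond to the backend's DEFAULT_PARAMETERS (nodes and dropout_rate).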