Passed
Pull Request — main (#926)
by Osma
06:33 queued 03:02

annif.backend.nn_ensemble.LMDBSequence.__len__()   A

Complexity
  Conditions: 1

Size
  Total Lines: 3
  Code Lines: 2

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
eloc    2
dl      0
loc     3
rs      10
c       0
b       0
f       0
cc      1
nop     1
"""Neural network based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import os.path
import shutil
import sys
from io import BytesIO
from typing import TYPE_CHECKING, Any

import joblib
import lmdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import csc_matrix, csr_matrix
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

import annif.corpus
import annif.parallel
import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch, vector_to_suggestions

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

logger = annif.logger


def idx_to_key(idx: int) -> bytes:
    """convert an integer index to a binary key for use in LMDB"""
    return b"%08d" % idx


def key_to_idx(key: memoryview | bytes) -> int:
    """convert a binary LMDB key to an integer index"""
    return int(key)


class LMDBDataset(Dataset):
    """A sequence of samples stored in a LMDB database."""

    def __init__(self, txn):
        super().__init__()
        self._txn = txn
        cursor = txn.cursor()
        if cursor.last():
            # Counter holds the number of samples in the database
            self._counter = key_to_idx(cursor.key()) + 1
        else:  # empty database
            self._counter = 0

    def add_sample(self, inputs: np.ndarray, targets: np.ndarray) -> None:
        # use zero-padded 8-digit key
        key = idx_to_key(self._counter)
        self._counter += 1
        # convert the sample into a sparse matrix and serialize it as bytes
        sample = (csc_matrix(inputs), csr_matrix(targets))
        buf = BytesIO()
        joblib.dump(sample, buf)
        buf.seek(0)
        self._txn.put(key, buf.read())

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """get a particular sample"""
        cursor = self._txn.cursor()
        cursor.set_key(idx_to_key(idx))
        value = cursor.value()
        input_csr, target_csr = joblib.load(BytesIO(value))
        input_tensor = torch.from_numpy(input_csr.toarray())
        target_tensor = torch.from_numpy(target_csr.toarray()[0]).float()
        return input_tensor, target_tensor

    def __len__(self) -> int:
        """return the number of available samples"""
        return self._counter


class NNEnsembleModel(nn.Module):
    def __init__(
        self, input_dim: int, hidden_dim: int, output_dim: int, dropout_rate: float
    ):
        super().__init__()
        self.model_config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "dropout_rate": dropout_rate,
        }
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(dropout_rate)
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.delta_layer = nn.Linear(hidden_dim, output_dim)

    def forward(self, inputs):
        mean = torch.mean(inputs, dim=1)
        x = self.flatten(inputs)
        x = self.dropout1(x)
        x = F.relu(self.hidden(x))
        x = self.dropout2(x)
        delta = self.delta_layer(x)
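        # the output is the element-wise mean of the source score vectors plus
        # a learned correction ("delta"), so the network adjusts a plain
        # average rather than predicting scores from scratch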
        return mean + delta

    def save(self, filepath):
        torch.save(
            {
                "model_state_dict": self.state_dict(),
                "model_class": self.__class__.__name__,
                "model_config": self.model_config,
                "pytorch_version": str(torch.__version__),
                "python_version": sys.version,
            },
            filepath,
        )

    @classmethod
    def load(cls, filepath, map_location="cpu"):
        checkpoint = torch.load(filepath, map_location=map_location, weights_only=True)
        config = checkpoint["model_config"]
        model = cls(**config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.eval()
        return model


class NNEnsembleBackend(backend.AnnifLearningBackend, ensemble.BaseEnsembleBackend):
    """Neural network ensemble backend that combines results from multiple
    projects"""

    name = "nn_ensemble"

    MODEL_FILE = "nn-model.pt"
    LMDB_FILE = "nn-train.mdb"

    DEFAULT_PARAMETERS = {
        "nodes": 100,
        "dropout_rate": 0.2,
        "optimizer": "adam",
        "lr": 0.001,
        "epochs": 10,
        "learn-epochs": 1,
        "lmdb_map_size": 1024 * 1024 * 1024,
    }

    # defaults for uninitialized instances
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        if self._model is not None:
            return  # already initialized
        if parallel:
            # Don't load model just before parallel execution,
            # since it won't work after forking worker processes
            return
        model_filename = os.path.join(self.datadir, self.MODEL_FILE)
        if not os.path.exists(model_filename):
            raise NotInitializedException(
                "model file {} not found".format(model_filename),
                backend_id=self.backend_id,
            )
        self.debug("loading model from {}".format(model_filename))
        try:
            self._model = NNEnsembleModel.load(model_filename)
        except Exception as err:
            message = (
                f"loading model from {model_filename}; "
                f'original error message: "{err}"'
            )
            raise OperationFailedException(message, backend_id=self.backend_id)

    def _merge_source_batches(
        self,
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        src_weight = dict(sources)
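        # scale each source's scores the same way as at training time in
        # _corpus_to_vectors: square-rooted scores multiplied by the source
        # weight and the number of sources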
        score_vectors = np.array(
            [
                [
                    np.sqrt(suggestions.as_vector())
                    * src_weight[project_id]
                    * len(batch_by_source)
                    for suggestions in batch
                ]
                for project_id, batch in batch_by_source.items()
            ],
            dtype=np.float32,
        )
        score_vector_tensor = torch.from_numpy(score_vectors.swapaxes(0, 1))
        with torch.no_grad():
            prediction = self._model(score_vector_tensor)
        return SuggestionBatch.from_sequence(
            [
                vector_to_suggestions(row, limit=int(params["limit"]))
                for row in prediction
            ],
            self.project.subjects,
        )

    def _create_model(self, sources: list[tuple[str, float]]) -> None:
        self.info("creating NN ensemble model")

        # Create PyTorch model
        input_dim = len(self.project.subjects) * len(sources)
        hidden_dim = int(self.params["nodes"])
        output_dim = len(self.project.subjects)
        dropout_rate = float(self.params["dropout_rate"])

        self._model = NNEnsembleModel(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    #        summary = []
    #        self._model.summary(print_fn=summary.append)
    #        self.debug("Created model: \n" + "\n".join(summary))

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        sources = annif.util.parse_sources(self.params["sources"])
        self._create_model(sources)
        self._fit_model(
            corpus,
            epochs=int(params["epochs"]),
            lmdb_map_size=int(params["lmdb_map_size"]),
            n_jobs=jobs,
        )

    def _corpus_to_vectors(
        self,
        corpus: DocumentCorpus,
        seq: LMDBDataset,
        n_jobs: int,
    ) -> None:
        # pass corpus through all source projects
        sources = dict(annif.util.parse_sources(self.params["sources"]))

        # initialize the source projects before forking, to save memory
        self.info(f"Initializing source projects: {', '.join(sources.keys())}")
        for project_id in sources.keys():
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self.info("Processing training documents...")
        with pool_class(jobs) as pool:
            for hits, subject_set in pool.imap_unordered(
                psmap.suggest, corpus.documents
            ):
                doc_scores = []
                for project_id, p_hits in hits.items():
                    vector = p_hits.as_vector()
                    doc_scores.append(
                        np.sqrt(vector) * sources[project_id] * len(sources)
                    )
                score_vector = np.array(doc_scores, dtype=np.float32)
                true_vector = subject_set.as_vector(len(self.project.subjects))
                seq.add_sample(score_vector, true_vector)

    def _open_lmdb(self, cached, lmdb_map_size):
        lmdb_path = os.path.join(self.datadir, self.LMDB_FILE)
        if not cached and os.path.exists(lmdb_path):
            shutil.rmtree(lmdb_path)
        return lmdb.open(lmdb_path, map_size=lmdb_map_size, writemap=True, mode=0o775)

    def _fit_model(
        self,
        corpus: DocumentCorpus,
        epochs: int,
        lmdb_map_size: int,
        n_jobs: int = 1,
    ) -> None:
        env = self._open_lmdb(corpus == "cached", lmdb_map_size)
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    "Cannot train nn_ensemble project with no documents"
                )
            with env.begin(write=True, buffers=True) as txn:
                seq = LMDBDataset(txn)
                self._corpus_to_vectors(corpus, seq, n_jobs)
        else:
            self.info("Reusing cached training data from previous run.")

        # fit the model using a read-only view of the LMDB
        self.info("Training neural network model...")
        with env.begin(buffers=True) as txn:
            dataset = LMDBDataset(txn)
            dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

            # Training loop
            optimizer = torch.optim.Adam(
                self._model.parameters(), lr=float(self.params["lr"])
            )
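            # BCEWithLogitsLoss combines the sigmoid with binary cross-entropy,
            # so the model output is treated as raw logits during training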
323
            criterion = nn.BCEWithLogitsLoss()
324
325
            self._model.train()
326
            for epoch in range(epochs):
327
                tqdm_loader = tqdm(
328
                    dataloader,
329
                    desc=f"Epoch {epoch + 1}/{epochs}",
330
                    postfix={"loss": "0.000"},
331
                )
332
                for inputs, targets in tqdm_loader:
333
                    optimizer.zero_grad()
334
                    outputs = self._model(inputs)
335
                    loss = criterion(outputs, targets)
336
                    loss.backward()
337
                    optimizer.step()
338
                    tqdm_loader.set_postfix(loss=loss.item())
339
                    tqdm_loader.update()
340
341
        annif.util.atomic_save(self._model, self.datadir, self.MODEL_FILE)
342
343
    def _learn(
344
        self,
345
        corpus: DocumentCorpus,
346
        params: dict[str, Any],
347
    ) -> None:
348
        self.initialize()
349
        self._fit_model(
350
            corpus, int(params["learn-epochs"]), int(params["lmdb_map_size"])
351
        )
352