Passed
Pull Request — main (#798), created by unknown at 03:19

XTransformerBackend._suggest_batch() (rated A)

Complexity:  Conditions 4
Size:        Total Lines 25, Code Lines 22
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      4
eloc    22
nop     3
dl      0
loc     25
rs      9.352
c       0
b       0
f       0
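
The per-function figures above (cc, eloc, nop, loc) are standard static metrics. A minimal sketch of reproducing comparable numbers locally with the radon library follows; this assumes radon's metric definitions, which need not match this review tool's exactly, and the file path is illustrative.

# Sketch: compute size and cyclomatic-complexity metrics with radon.
# Assumption: radon is installed and the path points at the reviewed file.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("annif/backend/xtransformer.py", encoding="utf-8") as f:
    source = f.read()

raw = analyze(source)  # physical and source line counts
print(f"loc={raw.loc} sloc={raw.sloc} comments={raw.comments}")

for block in cc_visit(source):  # per-function cyclomatic complexity
    print(f"{block.name}: cc={block.complexity}")
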
"""Annif backend using the transformer variant of pecos."""

from __future__ import annotations

import logging
import os.path as osp
from sys import stdout
from typing import TYPE_CHECKING, Any

import numpy as np
import scipy.sparse as sp
from pecos.utils.featurization.text.preprocess import Preprocessor
from pecos.xmc.xtransformer import matcher, model
from pecos.xmc.xtransformer.model import XTransformer
from pecos.xmc.xtransformer.module import MLProblemWithText

from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch
from annif.util import (
    apply_param_parse_config,
    atomic_save,
    atomic_save_folder,
    boolean,
)

from . import backend, mixins

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class XTransformerBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """XTransformer based backend for Annif"""

    name = "xtransformer"
    needs_subject_index = True

    _model = None

    train_X_file = "xtransformer-train-X.npz"
    train_y_file = "xtransformer-train-y.npz"
    train_txt_file = "xtransformer-train-raw.txt"
    model_folder = "xtransformer-model"

    PARAM_CONFIG = {
        "min_df": int,
        "ngram": int,
        "fix_clustering": boolean,
        "nr_splits": int,
        "min_codes": int,
        "max_leaf_size": int,
        "imbalanced_ratio": float,
        "imbalanced_depth": int,
        "max_match_clusters": int,
        "do_fine_tune": boolean,
        "model_shortcut": str,
        "beam_size": int,
        "limit": int,
        "post_processor": str,
        "negative_sampling": str,
        "ensemble_method": str,
        "threshold": float,
        "loss_function": str,
        "truncate_length": int,
        "hidden_dropout_prob": float,
        "batch_size": int,
        "gradient_accumulation_steps": int,
        "learning_rate": float,
        "weight_decay": float,
        "adam_epsilon": float,
        "num_train_epochs": int,
        "max_steps": int,
        "lr_schedule": str,
        "warmup_steps": int,
        "logging_steps": int,
        "save_steps": int,
        "max_active_matching_labels": int,
        "max_num_labels_in_gpu": int,
        "use_gpu": boolean,
        "bootstrap_model": str,
    }

    DEFAULT_PARAMETERS = {
        "min_df": 1,
        "ngram": 1,
        "fix_clustering": False,
        "nr_splits": 16,
        "min_codes": None,
        "max_leaf_size": 100,
        "imbalanced_ratio": 0.0,
        "imbalanced_depth": 100,
        "max_match_clusters": 32768,
        "do_fine_tune": True,
        "model_shortcut": "distilbert-base-multilingual-uncased",
        "beam_size": 20,
        "limit": 100,
        "post_processor": "sigmoid",
        "negative_sampling": "tfn",
        "ensemble_method": "transformer-only",
        "threshold": 0.1,
        "loss_function": "squared-hinge",
        "truncate_length": 128,
        "hidden_dropout_prob": 0.1,
        "batch_size": 32,
        "gradient_accumulation_steps": 1,
        "learning_rate": 1e-4,
        "weight_decay": 0.0,
        "adam_epsilon": 1e-8,
        "num_train_epochs": 1,
        "max_steps": 0,
        "lr_schedule": "linear",
        "warmup_steps": 0,
        "logging_steps": 100,
        "save_steps": 1000,
        "max_active_matching_labels": None,
        "max_num_labels_in_gpu": 65536,
        "use_gpu": True,
        "bootstrap_model": "linear",
    }

    def _initialize_model(self):
        if self._model is None:
            path = osp.join(self.datadir, self.model_folder)
            self.debug("loading model from {}".format(path))
            if osp.exists(path):
                self._model = XTransformer.load(path)
            else:
                raise NotInitializedException(
                    "model {} not found".format(path), backend_id=self.backend_id
                )

    def initialize(self, parallel: bool = False) -> None:
        self.initialize_vectorizer()
        self._initialize_model()

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _create_train_files(self, veccorpus, corpus):
        self.info("creating train file")
        Xs = []
        ys = []
        txt_pth = osp.join(self.datadir, self.train_txt_file)
        with open(txt_pth, "w", encoding="utf-8") as txt_file:
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_set = doc.subject_set
                if not (subject_set and doc.text):
                    continue  # noqa
                print(" ".join(doc.text.split()), file=txt_file)
                Xs.append(sp.csr_matrix(vector, dtype=np.float32).sorted_indices())
                ys.append(
                    sp.csr_matrix(
                        (
                            np.ones(len(subject_set)),
                            (np.zeros(len(subject_set)), list(subject_set)),
                        ),
                        shape=(1, len(self.project.subjects)),
                        dtype=np.float32,
                    ).sorted_indices()
                )
        atomic_save(
            sp.vstack(Xs, format="csr"),
            self.datadir,
            self.train_X_file,
            method=lambda mtrx, target: sp.save_npz(target, mtrx, compressed=True),
        )
        atomic_save(
            sp.vstack(ys, format="csr"),
            self.datadir,
            self.train_y_file,
            method=lambda mtrx, target: sp.save_npz(target, mtrx, compressed=True),
        )

    def _create_model(self, params, jobs):
        train_txts = Preprocessor.load_data_from_file(
            osp.join(self.datadir, self.train_txt_file),
            label_text_path=None,
            text_pos=0,
        )["corpus"]
        train_X = sp.load_npz(osp.join(self.datadir, self.train_X_file))
        train_y = sp.load_npz(osp.join(self.datadir, self.train_y_file))
        model_path = osp.join(self.datadir, self.model_folder)
        new_params = apply_param_parse_config(self.PARAM_CONFIG, params)
        new_params["only_topk"] = new_params.pop("limit")
        train_params = XTransformer.TrainParams.from_dict(
            new_params, recursive=True
        ).to_dict()
        pred_params = XTransformer.PredParams.from_dict(
            new_params, recursive=True
        ).to_dict()

        self.info("Start training")
        # enable pecos progress logging on stdout
        matcher.LOGGER.setLevel(logging.DEBUG)
        matcher.LOGGER.addHandler(logging.StreamHandler(stream=stdout))
        model.LOGGER.setLevel(logging.DEBUG)
        model.LOGGER.addHandler(logging.StreamHandler(stream=stdout))
        self._model = XTransformer.train(
            MLProblemWithText(train_txts, train_y, X_feat=train_X),
            clustering=None,
            val_prob=None,
            train_params=train_params,
            pred_params=pred_params,
            beam_size=int(params["beam_size"]),
            steps_scale=None,
            label_feat=None,
        )
        atomic_save_folder(self._model, model_path)

View Code Duplication (0 ignored issues): This code seems to be duplicated in your project.

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        if corpus == "cached":
            self.info("Reusing cached training data from previous run.")
        else:
            if corpus.is_empty():
                raise NotSupportedException("Cannot train project with no documents")
            inputs = (doc.text for doc in corpus.documents)
            vecparams = {
                "min_df": int(params["min_df"]),
                "tokenizer": self.project.analyzer.tokenize_words,
                "ngram_range": (1, int(params["ngram"])),
            }
            veccorpus = self.create_vectorizer(inputs, vecparams)
            self._create_train_files(veccorpus, corpus)
        self._create_model(params, jobs)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        vector = self.vectorizer.transform(texts)
        if vector.nnz == 0:  # all-zero vectors: return one empty result per text
            return SuggestionBatch.from_sequence(
                [[] for _ in texts], self.project.subjects
            )
        new_params = apply_param_parse_config(self.PARAM_CONFIG, params)
        prediction = self._model.predict(
            texts,
            X_feat=vector.sorted_indices(),
            batch_size=new_params["batch_size"],
            use_gpu=new_params["use_gpu"],
            only_top_k=new_params["limit"],
            post_processor=new_params["post_processor"],
        )
        current_batchsize = prediction.get_shape()[0]
        batch_result = []
        for i in range(current_batchsize):
            results = []
            row = prediction.getrow(i)
            for idx, score in zip(row.indices, row.data):
                results.append(SubjectSuggestion(subject_id=idx, score=score))
            batch_result.append(results)
        return SuggestionBatch.from_sequence(batch_result, self.project.subjects)
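
The trickiest step above is the per-document label matrix built in _create_train_files(). A minimal, self-contained sketch of that construction follows; the subject ids and subject count are made up for illustration and stand in for doc.subject_set and len(self.project.subjects).

# Sketch: one 1 x n_subjects CSR row with ones at the document's subject ids,
# as built per document in _create_train_files(). Values are illustrative.
import numpy as np
import scipy.sparse as sp

subject_set = [3, 7, 42]  # subject ids assigned to one document (hypothetical)
n_subjects = 100          # stands in for len(self.project.subjects)

row = sp.csr_matrix(
    (
        np.ones(len(subject_set)),                  # one 1.0 per subject
        (np.zeros(len(subject_set)), subject_set),  # (row, column) coordinates
    ),
    shape=(1, n_subjects),
    dtype=np.float32,
).sorted_indices()

assert row.nnz == len(subject_set)  # exactly one nonzero per assigned subject
print(row.indices)                  # -> [ 3  7 42]

Stacking one such row per document with sp.vstack(..., format="csr"), as the training code does, yields the documents-by-subjects target matrix expected by pecos.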