Passed
Pull Request — master (#540) by unknown, created 03:31

XTransformerBackend._create_model()   Grade: A

Complexity
    Conditions: 1

Size
    Total Lines: 34
    Code Lines: 31

Duplication
    Lines: 34
    Ratio: 100 %

Importance
    Changes: 0

Metric   Value
cc       1
eloc     31
nop      3
dl       34
loc      34
rs       9.1359
c        0
b        0
f        0
"""Annif backend using the transformer variant of pecos."""

from sys import stdout
import os.path as osp
import logging
import scipy.sparse as sp
import numpy as np

from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import ListSuggestionResult, SubjectSuggestion
from . import mixins
from . import backend
from annif.util import boolean, apply_param_parse_config, atomic_save_folder, \
    atomic_save

from pecos.xmc.xtransformer.model import XTransformer
from pecos.xmc.xtransformer.module import MLProblemWithText
from pecos.utils.featurization.text.preprocess import Preprocessor
from pecos.xmc.xtransformer import matcher

class XTransformerBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """XTransformer based backend for Annif"""
24
    name = 'xtransformer'
25
    needs_subject_index = True
26
27
    _model = None
28
29
    train_X_file = 'xtransformer-train-X.npz'
30
    train_y_file = 'xtransformer-train-y.npz'
31
    train_txt_file = 'xtransformer-train-raw.txt'
32
    model_folder = 'xtransformer-model'
33
34
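    # Review note: PARAM_CONFIG maps each backend parameter to the callable
    # used by apply_param_parse_config() to cast the raw string value from
    # the project configuration, e.g. a configured use_gpu=true becomes the
    # boolean True.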
    PARAM_CONFIG = {
        'min_df': int,
        'ngram': int,
        'fix_clustering': boolean,
        'nr_splits': int,
        'min_codes': int,
        'max_leaf_size': int,
        'imbalanced_ratio': float,
        'imbalanced_depth': int,
        'max_match_clusters': int,
        'do_fine_tune': boolean,
        'model_shortcut': str,
        'beam_size': int,
        'limit': int,
        'post_processor': str,
        'negative_sampling': str,
        'ensemble_method': str,
        'threshold': float,
        'loss_function': str,
        'truncate_length': int,
        'hidden_dropout_prob': float,
        'batch_size': int,
        'gradient_accumulation_steps': int,
        'learning_rate': float,
        'weight_decay': float,
        'adam_epsilon': float,
        'num_train_epochs': int,
        'max_steps': int,
        'lr_schedule': str,
        'warmup_steps': int,
        'logging_steps': int,
        'save_steps': int,
        'max_active_matching_labels': int,
        'max_num_labels_in_gpu': int,
        'use_gpu': boolean,
        'bootstrap_model': str
    }
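    # Review note: these defaults feed default_params() below; the
    # 'model_shortcut' value is a Hugging Face checkpoint id
    # (bert-base-multilingual-uncased), with a distilbert alternative left
    # commented out.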
    DEFAULT_PARAMETERS = {
        'min_df': 1,
        'ngram': 1,
        'fix_clustering': False,
        'nr_splits': 16,
        'min_codes': None,
        'max_leaf_size': 100,
        'imbalanced_ratio': 0.0,
        'imbalanced_depth': 100,
        'max_match_clusters': 32768,
        'do_fine_tune': True,
        # 'model_shortcut': 'distilbert-base-multilingual-cased',
        'model_shortcut': 'bert-base-multilingual-uncased',
        'beam_size': 20,
        'limit': 100,
        'post_processor': 'sigmoid',
        'negative_sampling': 'tfn',
        'ensemble_method': 'transformer-only',
        'threshold': 0.1,
        'loss_function': 'squared-hinge',
        'truncate_length': 128,
        'hidden_dropout_prob': 0.1,
        'batch_size': 32,
        'gradient_accumulation_steps': 1,
        'learning_rate': 1e-4,
        'weight_decay': 0.0,
        'adam_epsilon': 1e-8,
        'num_train_epochs': 1,
        'max_steps': 0,
        'lr_schedule': 'linear',
        'warmup_steps': 0,
        'logging_steps': 100,
        'save_steps': 1000,
        'max_active_matching_labels': None,
        'max_num_labels_in_gpu': 65536,
        'use_gpu': True,
        'bootstrap_model': 'linear'
    }
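    # Review note: the model is loaded lazily from the data directory; a
    # NotInitializedException here means the project has not been trained yet.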
    def _initialize_model(self):
        if self._model is None:
            path = osp.join(self.datadir, self.model_folder)
            self.debug('loading model from {}'.format(path))
            if osp.exists(path):
                self._model = XTransformer.load(path)
            else:
                raise NotInitializedException(
                    'model {} not found'.format(path),
                    backend_id=self.backend_id)

    def initialize(self, parallel=False):
        self.initialize_vectorizer()
        self._initialize_model()

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(self.DEFAULT_PARAMETERS)
        return params
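    # Review note: training data is cached on disk in three parts so that a
    # later run with corpus='cached' can skip vectorization: raw text lines
    # (train_txt_file), the tf-idf feature matrix (train_X_file) and the
    # sparse binary label matrix (train_y_file).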
    def _create_train_files(self, veccorpus, corpus):
        self.info('creating train file')
        Xs = []
        ys = []
        txt_pth = osp.join(self.datadir, self.train_txt_file)
        with open(txt_pth, 'w', encoding='utf-8') as txt_file:
            for doc, vector in zip(corpus.documents, veccorpus):
                subject_ids = [
                    self.project.subjects.by_uri(uri)
                    for uri
                    in doc.uris]
                subject_ids = [s_id for s_id in subject_ids
                               if s_id is not None]
                if not (subject_ids and doc.text):
                    continue  # noqa
                print(' '.join(doc.text.split()), file=txt_file)
                Xs.append(
                    sp.csr_matrix(vector, dtype=np.float32).sorted_indices())
                ys.append(
                    sp.csr_matrix((
                        np.ones(len(subject_ids)),
                        (
                            np.zeros(len(subject_ids)),
                            subject_ids)),
                        shape=(1, len(self.project.subjects)),
                        dtype=np.float32
                        ).sorted_indices())
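                # Each ys entry built above is a 1 x n_subjects indicator row
                # using the (data, (row, col)) form of csr_matrix: ones in the
                # columns named by subject_ids, row index always 0. An
                # illustrative stand-alone equivalent:
                #   sp.csr_matrix((np.ones(2), (np.zeros(2), [5, 17])),
                #                 shape=(1, N))
                # gives one row with 1.0 in columns 5 and 17.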
        atomic_save(
            sp.vstack(Xs, format='csr'),
            self.datadir,
            self.train_X_file,
            method=lambda mtrx, target: sp.save_npz(
                target,
                mtrx,
                compressed=True))
        atomic_save(
            sp.vstack(ys, format='csr'),
            self.datadir,
            self.train_y_file,
            method=lambda mtrx, target: sp.save_npz(
                target,
                mtrx,
                compressed=True))
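    # Review note: _create_model() always reads the training data back from
    # the cached files, so fresh and 'cached' training paths share the same
    # on-disk representation.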
    def _create_model(self, params, jobs):
        train_txts = Preprocessor.load_data_from_file(
            osp.join(self.datadir, self.train_txt_file),
            label_text_path=None,
            text_pos=0)['corpus']
        train_X = sp.load_npz(osp.join(self.datadir, self.train_X_file))
        train_y = sp.load_npz(osp.join(self.datadir, self.train_y_file))
        model_path = osp.join(self.datadir, self.model_folder)
        new_params = apply_param_parse_config(
            self.PARAM_CONFIG,
            self.params)
        new_params['only_topk'] = new_params.pop('limit')
        train_params = XTransformer.TrainParams.from_dict(
            new_params,
            recursive=True).to_dict()
        pred_params = XTransformer.PredParams.from_dict(
            new_params,
            recursive=True).to_dict()

        self.info('Start training')
        # enable progress reporting from the pecos matcher
        matcher.LOGGER.setLevel(logging.INFO)
        matcher.LOGGER.addHandler(logging.StreamHandler(stream=stdout))
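        # MLProblemWithText bundles the training inputs for pecos: the raw
        # text lines, the sparse label matrix train_y, and the tf-idf feature
        # matrix train_X passed as X_feat.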
        self._model = XTransformer.train(
            MLProblemWithText(train_txts, train_y, X_feat=train_X),
            clustering=None,
            val_prob=None,
            train_params=train_params,
            pred_params=pred_params,
            beam_size=params['beam_size'],
            steps_scale=None,
            label_feat=None,
        )
        atomic_save_folder(self._model, model_path)
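    # Review note: corpus may be the string 'cached', in which case the train
    # files written by a previous run are reused instead of re-vectorizing.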
    def _train(self, corpus, params, jobs=0):
        if corpus == 'cached':
            self.info("Reusing cached training data from previous run.")
        else:
            if corpus.is_empty():
                raise NotSupportedException(
                    'Cannot train project with no documents')
            inputs = (doc.text for doc in corpus.documents)
            vecparams = {'min_df': int(params['min_df']),
                         'tokenizer': self.project.analyzer.tokenize_words,
                         'ngram_range': (1, int(params['ngram']))}
            veccorpus = self.create_vectorizer(inputs, vecparams)
            self._create_train_files(veccorpus, corpus)
        self._create_model(params, jobs)
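    # Review note: suggestion operates on one whitespace-normalized document;
    # the tf-idf vector doubles as an emptiness check and as the X_feat input
    # to XTransformer.predict().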
    def _suggest(self, text, params):
        text = ' '.join(text.split())
        vector = self.vectorizer.transform([text])
        if vector.nnz == 0:  # All zero vector, empty result
            return ListSuggestionResult([])
        new_params = apply_param_parse_config(
            self.PARAM_CONFIG,
            params
        )
        prediction = self._model.predict(
            [text],
            X_feat=vector.sorted_indices(),
            batch_size=new_params['batch_size'],
            use_gpu=False,
            only_top_k=new_params['limit'],
            post_processor=new_params['post_processor'])
        results = []
        for idx, score in zip(prediction.indices, prediction.data):
            subject = self.project.subjects[idx]
            results.append(SubjectSuggestion(
                uri=subject[0],
                label=subject[1],
                notation=subject[2],
                score=score
            ))
        return ListSuggestionResult(results)
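
For reference, the prediction consumed in _suggest() behaves like a one-row
sparse matrix whose nonzero columns are subject ids and whose values are
scores. A minimal standalone sketch of that mapping step, with a hand-built
scipy row standing in for real model output (values are illustrative only):

import numpy as np
import scipy.sparse as sp

# Stand-in for XTransformer.predict() output on one document:
# a 1 x n_subjects CSR row holding scores for three subjects.
n_subjects = 100
prediction = sp.csr_matrix(
    (np.array([0.91, 0.40, 0.13]),             # scores
     (np.zeros(3, dtype=int), [7, 42, 99])),   # row 0, subject id columns
    shape=(1, n_subjects))

# Same iteration pattern as _suggest(): nonzero columns and their values.
for idx, score in zip(prediction.indices, prediction.data):
    print(idx, score)   # prints 7 0.91, then 42 0.4, then 99 0.13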