Passed
Pull Request — main (#914)
created by unknown · 02:58

annif.backend.ebm.EbmBackend._train() (grade: C)

Complexity: Conditions 8
Size: Total Lines 93, Code Lines 73
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric  Value    Notes
cc      8        cyclomatic complexity (matches Conditions above)
eloc    73       effective code lines (matches Code Lines above)
nop     4        number of parameters of _train() (self, corpus, params, jobs)
dl      0        duplicated lines (matches Duplication above)
loc     93       total lines (matches Total Lines above)
rs      6.0169
c       0
b       0
f       0
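
The complexity and size figures can be reproduced locally. The snippet below uses radon as the measuring tool, which is an assumption: the report does not name its analyzer, but cc and loc are standard metrics that radon also computes.

# Sketch: recompute the cc and loc metrics for this module with radon.
# Using radon here is an assumption; the report does not identify its analyzer.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("annif/backend/ebm.py") as src:
    source = src.read()

# cc_visit() returns one block per function/method together with its
# cyclomatic complexity (the "cc" / Conditions figure above).
for block in cc_visit(source):
    print(block.name, block.complexity)

# analyze() returns raw size counts: total lines (loc), logical and
# source lines, comments, and blanks.
print(analyze(source))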

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. And if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

The refactoring most commonly applied to a Long Method is Extract Method: moving a cohesive part of the body into a new, well-named method, as sketched below for the flagged _train().
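
As an illustration, the document-flattening loop and the cache-loading branch of _train() could each become a named helper. This is only a sketch: the helper names _prepare_train_data and _load_cached_train_data are hypothetical, not part of Annif or this pull request.

    def _prepare_train_data(self, corpus: DocumentCorpus, jobs: int):
        # Flatten each (document, subject) pair into parallel lists,
        # then hand them to the model's feature-preparation step.
        doc_ids, texts, label_ids = [], [], []
        for doc_id, doc in enumerate(corpus.documents):
            for subject_id in doc.subject_set:
                doc_ids.append(doc_id)
                texts.append(doc.text)
                label_ids.append(self.project.subjects[subject_id].uri)
        return self._model.prepare_train(
            doc_ids=doc_ids, label_ids=label_ids, texts=texts, n_jobs=jobs
        )

    def _load_cached_train_data(self):
        # Reuse the training data saved by a previous run, failing fast
        # if the cached file is missing.
        path = os.path.join(self.datadir, self.TRAIN_FILE)
        if not os.path.exists(path):
            raise NotInitializedException(
                f"train data file {self.TRAIN_FILE} not found",
                backend_id=self.backend_id,
            )
        return joblib.load(path)

With both helpers in place, _train() shrinks to model construction, a branch choosing between the two helpers, and the final train-and-save steps, which also lowers the number of conditions counted against _train() itself.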

import os
from typing import Any

import joblib
import numpy as np
from ebm4subjects.ebm_model import EbmModel

from annif.analyzer.analyzer import Analyzer
from annif.corpus.document import Document, DocumentCorpus
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import SuggestionBatch, vector_to_suggestions
from annif.util import atomic_save

from . import backend

class EbmBackend(backend.AnnifBackend):
    name = "ebm"

    # Configuration parameters and their expected types
    EBM_PARAMETERS = {
        "embedding_model_name": str,
        "embedding_dimensions": int,
        "max_chunk_count": int,
        "max_chunk_length": int,
        "chunking_jobs": int,
        "max_sentence_count": int,
        "hnsw_index_params": dict[str, Any],
        "candidates_per_chunk": int,
        "candidates_per_doc": int,
        "query_jobs": int,
        "xgb_shrinkage": float,
        "xgb_interaction_depth": int,
        "xgb_subsample": float,
        "xgb_rounds": int,
        "xgb_jobs": int,
        "duckdb_threads": int,
        "use_altLabels": bool,
        "model_args": dict[str, Any],
        "encode_args_vocab": dict[str, Any],
        "encode_args_documents": dict[str, Any],
    }

    DEFAULT_PARAMETERS = {
        "embedding_model_name": "BAAI/bge-m3",
        "embedding_dimensions": 1024,
        "max_chunk_count": 100,
        "max_chunk_length": 50,
        "chunking_jobs": 1,
        "max_sentence_count": 100,
        "hnsw_index_params": {"M": 32, "ef_construction": 256, "ef_search": 256},
        "candidates_per_chunk": 20,
        "candidates_per_doc": 100,
        "query_jobs": 1,
        "xgb_shrinkage": 0.03,
        "xgb_interaction_depth": 5,
        "xgb_subsample": 0.7,
        "xgb_rounds": 300,
        "xgb_jobs": 1,
        "duckdb_threads": 1,
        "use_altLabels": True,
        "model_args": {"device": "cpu", "trust_remote_code": False},
        "encode_args_vocab": {"batch_size": 32, "show_progress_bar": True},
        "encode_args_documents": {"batch_size": 32, "show_progress_bar": True},
    }

    DB_FILE = "ebm-duck.db"
    MODEL_FILE = "ebm-model.gz"
    TRAIN_FILE = "ebm-train.gz"

    _analyzer = Analyzer()

    _model = None

    def initialize(self, parallel: bool = False) -> None:
        # Lazily load the trained model from the data directory
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)

            self.debug(f"loading model from {path}")
            if os.path.exists(path):
                self._model = EbmModel.load(path)
                self._model.init_logger(logger=self)
                self.debug("loaded model")
            else:
                raise NotInitializedException(
                    f"model not found at {path}", backend_id=self.backend_id
                )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        self.info("starting train")
        self._model = EbmModel(
            db_path=os.path.join(self.datadir, self.DB_FILE),
            embedding_model_name=params["embedding_model_name"],
            embedding_dimensions=params["embedding_dimensions"],
            chunk_tokenizer=self._analyzer,
            max_chunk_count=params["max_chunk_count"],
            max_chunk_length=params["max_chunk_length"],
            chunking_jobs=params["chunking_jobs"],
            max_sentence_count=params["max_sentence_count"],
            hnsw_index_params=params["hnsw_index_params"],
            candidates_per_chunk=params["candidates_per_chunk"],
            candidates_per_doc=params["candidates_per_doc"],
            query_jobs=params["query_jobs"],
            xgb_shrinkage=params["xgb_shrinkage"],
            xgb_interaction_depth=params["xgb_interaction_depth"],
            xgb_subsample=params["xgb_subsample"],
            xgb_rounds=params["xgb_rounds"],
            xgb_jobs=params["xgb_jobs"],
            duckdb_threads=jobs if jobs else params["duckdb_threads"],
            use_altLabels=params["use_altLabels"],
            model_args=params["model_args"],
            encode_args_vocab=params["encode_args_vocab"],
            encode_args_documents=params["encode_args_documents"],
            logger=self,
        )

        # corpus may be the sentinel value "cached", meaning: reuse the
        # training data prepared and saved by a previous run
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    f"training backend {self.backend_id} with no documents"
                )

            self.info("creating vector database")
            self._model.create_vector_db(
                vocab_in_path=os.path.join(
                    self.project.vocab.datadir, self.project.vocab.INDEX_FILENAME_TTL
                ),
                force=True,
            )

            self.info("preparing training data")
            # Flatten each (document, subject) pair into parallel lists
            doc_ids = []
            texts = []
            label_ids = []
            for doc_id, doc in enumerate(corpus.documents):
                for subject_id in doc.subject_set:
                    doc_ids.append(doc_id)
                    texts.append(doc.text)
                    label_ids.append(self.project.subjects[subject_id].uri)

            train_data = self._model.prepare_train(
                doc_ids=doc_ids,
                label_ids=label_ids,
                texts=texts,
                n_jobs=jobs,
            )

            atomic_save(
                obj=train_data,
                dirname=self.datadir,
                filename=self.TRAIN_FILE,
                method=joblib.dump,
            )

        else:
            self.info("reusing cached training data from previous run")
            if not os.path.exists(self._model.db_path):
                raise NotInitializedException(
                    f"database file {self._model.db_path} not found",
                    backend_id=self.backend_id,
                )
            if not os.path.exists(os.path.join(self.datadir, self.TRAIN_FILE)):
                raise NotInitializedException(
                    f"train data file {self.TRAIN_FILE} not found",
                    backend_id=self.backend_id,
                )

            train_data = joblib.load(os.path.join(self.datadir, self.TRAIN_FILE))

        self.info("training model")
        self._model.train(train_data, jobs)

        self.info("saving model")
        atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _suggest_batch(
        self, documents: list[Document], params: dict[str, Any]
    ) -> SuggestionBatch:
        candidates = self._model.generate_candidates_batch(
            texts=[doc.text for doc in documents],
            doc_ids=list(range(len(documents))),
        )

        predictions = self._model.predict(candidates)

        suggestions = []
        for doc_predictions in predictions:
            # Turn per-document predictions into a dense score vector
            # indexed by subject position, skipping unknown label URIs
            vector = np.zeros(len(self.project.subjects), dtype=np.float32)
            for row in doc_predictions.iter_rows(named=True):
                position = self.project.subjects._uri_idx.get(row["label_id"])
                if position is not None:
                    vector[position] = row["score"]
            suggestions.append(vector_to_suggestions(vector, int(params["limit"])))

        return SuggestionBatch.from_sequence(
            suggestions,
            self.project.subjects,
            limit=int(params["limit"]),
        )