Passed
Pull Request — main (#914), by unknown, created 04:33

annif.backend.ebm.EbmBackend.initialize()   A

Complexity:   Conditions 3
Size:         Total Lines 12, Code Lines 10
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      3
eloc    10
nop     2
dl      0
loc     12
rs      9.9
c       0
b       0
f       0
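
The "Conditions 3" / "cc 3" figure is the McCabe-style cyclomatic complexity of initialize(): one for the function itself plus one each for the `if self._model is None` and `if os.path.exists(path)` branches. As a rough cross-check (a hedged sketch, not necessarily the tool that produced this report; the module path is an assumption based on the fully qualified method name), the Python package radon reports comparable per-function complexity and raw size metrics:

# Hedged sketch: reproduce similar "cc" and line-count figures with radon.
# The file path below is an assumption, not taken from the report.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("annif/backend/ebm.py") as f:
    source = f.read()

# Per-function/method cyclomatic complexity; initialize() should come out at 3.
for block in cc_visit(source):
    print(f"{block.name}: cc={block.complexity}")

# Raw size metrics for the whole module (loc, lloc, sloc, comments, blank lines).
print(analyze(source))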
import os
from typing import Any

import joblib
import numpy as np
from ebm4subjects.ebm_model import EbmModel

from annif.analyzer.analyzer import Analyzer
from annif.corpus.document import Document, DocumentCorpus
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import SuggestionBatch, vector_to_suggestions
from annif.util import atomic_save

from . import backend


class EbmBackend(backend.AnnifBackend):
    name = "ebm"

    EBM_PARAMETERS = {
        "embedding_dimensions": int,
        "max_chunk_count": int,
        "max_chunk_length": int,
        "chunking_jobs": int,
        "max_sentence_count": int,
        "hnsw_index_params": dict[str, Any],
        "candidates_per_chunk": int,
        "candidates_per_doc": int,
        "query_jobs": int,
        "xgb_shrinkage": float,
        "xgb_interaction_depth": int,
        "xgb_subsample": float,
        "xgb_rounds": int,
        "xgb_jobs": int,
        "duck_db_threads": int,
        "use_altLabels": bool,
        "embedding_model_name": str,
        "embedding_model_deployment": str,
        "embedding_model_args": dict[str, Any],
        "encode_args_vocab": dict[str, Any],
        "encode_args_documents": dict[str, Any],
    }

    DEFAULT_PARAMETERS = {
        "embedding_dimensions": 1024,
        "max_chunk_count": 100,
        "max_chunk_length": 50,
        "chunking_jobs": 1,
        "max_sentence_count": 100,
        "hnsw_index_params": {"M": 32, "ef_construction": 256, "ef_search": 256},
        "candidates_per_chunk": 20,
        "candidates_per_doc": 100,
        "query_jobs": 1,
        "xgb_shrinkage": 0.03,
        "xgb_interaction_depth": 5,
        "xgb_subsample": 0.7,
        "xgb_rounds": 300,
        "xgb_jobs": 1,
        "duckdb_threads": 1,
        "use_altLabels": True,
        "embedding_model_name": None,
        "embedding_model_deployment": "mock",
        "embedding_model_args": None,
        "encode_args_vocab": None,
        "encode_args_documents": None,
    }

    DB_FILE = "ebm-duck.db"
    MODEL_FILE = "ebm-model.gz"
    TRAIN_FILE = "ebm-train.gz"

    _analyzer = Analyzer()

    _model = None

    def initialize(self, parallel: bool = False) -> None:
        if self._model is None:
            path = os.path.join(self.datadir, self.MODEL_FILE)

            self.debug(f"loading model from {path}")
            if os.path.exists(path):
                self._model = EbmModel.load(path)
                self._model.init_logger(logger=self)
                self.debug("loaded model")
            else:
                raise NotInitializedException(
                    f"model not found at {path}", backend_id=self.backend_id
                )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        self.info("starting train")
        self._model = EbmModel(
            db_path=os.path.join(self.datadir, self.DB_FILE),
            embedding_dimensions=params["embedding_dimensions"],
            chunk_tokenizer=self._analyzer,
            max_chunk_count=params["max_chunk_count"],
            max_chunk_length=params["max_chunk_length"],
            chunking_jobs=params["chunking_jobs"],
            max_sentence_count=params["max_sentence_count"],
            hnsw_index_params=params["hnsw_index_params"],
            candidates_per_chunk=params["candidates_per_chunk"],
            candidates_per_doc=params["candidates_per_doc"],
            query_jobs=params["query_jobs"],
            xgb_shrinkage=params["xgb_shrinkage"],
            xgb_interaction_depth=params["xgb_interaction_depth"],
            xgb_subsample=params["xgb_subsample"],
            xgb_rounds=params["xgb_rounds"],
            xgb_jobs=params["xgb_jobs"],
            duckdb_threads=jobs if jobs else params["duckdb_threads"],
            use_altLabels=params["use_altLabels"],
            embedding_model_name=params["embedding_model_name"],
            embedding_model_deployment=params["embedding_model_deployment"],
            embedding_model_args=params["embedding_model_args"],
            encode_args_vocab=params["encode_args_vocab"],
            encode_args_documents=params["encode_args_documents"],
            logger=self,
        )

        # the special corpus value "cached" means reuse training data prepared by a previous run
        if corpus != "cached":
            if corpus.is_empty():
                raise NotSupportedException(
                    f"training backend {self.backend_id} with no documents"
                )

            self.info("creating vector database")
            self._model.create_vector_db(
                vocab_in_path=os.path.join(
                    self.project.vocab.datadir, self.project.vocab.INDEX_FILENAME_TTL
                ),
                force=True,
            )

            self.info("preparing training data")
            doc_ids = []
            texts = []
            label_ids = []
            for doc_id, doc in enumerate(corpus.documents):
                for subject_id in doc.subject_set:
                    doc_ids.append(doc_id)
                    texts.append(doc.text)
                    label_ids.append(self.project.subjects[subject_id].uri)

            train_data = self._model.prepare_train(
                doc_ids=doc_ids,
                label_ids=label_ids,
                texts=texts,
                n_jobs=jobs,
            )

            atomic_save(
                obj=train_data,
                dirname=self.datadir,
                filename=self.TRAIN_FILE,
                method=joblib.dump,
            )

        else:
            self.info("reusing cached training data from previous run")
            if not os.path.exists(self._model.db_path):
                raise NotInitializedException(
                    f"database file {self._model.db_path} not found",
                    backend_id=self.backend_id,
                )
            if not os.path.exists(os.path.join(self.datadir, self.TRAIN_FILE)):
                raise NotInitializedException(
                    f"train data file {self.TRAIN_FILE} not found",
                    backend_id=self.backend_id,
                )

            train_data = joblib.load(os.path.join(self.datadir, self.TRAIN_FILE))

        self.info("training model")
        self._model.train(train_data, jobs)

        self.info("saving model")
        atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _suggest_batch(
        self, documents: list[Document], params: dict[str, Any]
    ) -> SuggestionBatch:
        candidates = self._model.generate_candidates_batch(
            texts=[doc.text for doc in documents],
            doc_ids=list(range(len(documents))),
        )

        predictions = self._model.predict(candidates)

        suggestions = []
        for doc_predictions in predictions:
            # build a dense score vector over the full subject vocabulary,
            # mapping each predicted label URI to its subject index position
            vector = np.zeros(len(self.project.subjects), dtype=np.float32)
            for row in doc_predictions.iter_rows(named=True):
                position = self.project.subjects._uri_idx.get(row["label_id"], 0)
                vector[position] = row["score"]
            suggestions.append(vector_to_suggestions(vector, int(params["limit"])))

        return SuggestionBatch.from_sequence(
            suggestions,
            self.project.subjects,
            limit=int(params.get("limit")),
        )
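
For context on the final step above: _suggest_batch spreads each document's predicted (label URI, score) pairs into a dense vector over the whole subject vocabulary and passes it to vector_to_suggestions together with the configured limit. Below is a minimal, self-contained sketch of that idea in plain NumPy; the URI-to-index mapping and scores are made-up toy values, not Annif's SubjectIndex API.

# Hedged illustration of the score-vector -> top-k suggestions step.
# uri_index and predicted are toy values, not Annif internals.
import numpy as np

uri_index = {"uri:a": 0, "uri:b": 1, "uri:c": 2, "uri:d": 3}     # URI -> vector position
predicted = [("uri:c", 0.91), ("uri:a", 0.40), ("uri:d", 0.12)]  # (label URI, score)

vector = np.zeros(len(uri_index), dtype=np.float32)
for uri, score in predicted:
    vector[uri_index[uri]] = score

limit = 2
# Indices of the `limit` highest scores, best first; zero-score entries are skipped.
top = np.argsort(vector)[::-1][:limit]
for idx in top:
    if vector[idx] > 0:
        print(int(idx), float(vector[idx]))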