|
1
|
|
|
import os |
|
2
|
|
|
from typing import Any |
|
3
|
|
|
|
|
4
|
|
|
import joblib |
|
5
|
|
|
import numpy as np |
|
6
|
|
|
from ebm4subjects.ebm_model import EbmModel |
|
7
|
|
|
|
|
8
|
|
|
from annif.analyzer.analyzer import Analyzer |
|
9
|
|
|
from annif.corpus.document import Document, DocumentCorpus |
|
10
|
|
|
from annif.exception import NotInitializedException, NotSupportedException |
|
11
|
|
|
from annif.suggestion import SuggestionBatch, vector_to_suggestions |
|
12
|
|
|
from annif.util import atomic_save |
|
13
|
|
|
|
|
14
|
|
|
from . import backend |
|
15
|
|
|
|
|
16
|
|
|
|
|
17
|
|
|
class EbmBackend(backend.AnnifBackend):
    """Annif backend that wraps an ``EbmModel`` from ``ebm4subjects``.

    Training builds a vector database from the project vocabulary, prepares
    training data from the document corpus, trains the model and saves it in
    the backend data directory. Suggesting generates candidates for a batch
    of documents, asks the model for predictions and converts them into
    per-document score vectors over the project's subject index.
    """

    name = "ebm"

    # Expected Python type for each backend-specific parameter. The keys must
    # match the keys of DEFAULT_PARAMETERS and the lookups done in _train().
    EBM_PARAMETERS = {
        "embedding_model_name": str,
        "embedding_dimensions": int,
        "max_chunk_count": int,
        "max_chunk_length": int,
        "chunking_jobs": int,
        "max_sentence_count": int,
        "hnsw_index_params": dict[str, Any],
        "candidates_per_chunk": int,
        "candidates_per_doc": int,
        "query_jobs": int,
        "xgb_shrinkage": float,
        "xgb_interaction_depth": int,
        "xgb_subsample": float,
        "xgb_rounds": int,
        "xgb_jobs": int,
        # Spelled "duckdb_threads" (not "duck_db_threads") so that it agrees
        # with DEFAULT_PARAMETERS and with the key read in _train().
        "duckdb_threads": int,
        "use_altLabels": bool,
        "model_args": dict[str, Any],
        "encode_args_vocab": dict[str, Any],
        "encode_args_documents": dict[str, Any],
    }

    # Values used for parameters that are not set explicitly.
    DEFAULT_PARAMETERS = {
        "embedding_model_name": "BAAI/bge-m3",
        "embedding_dimensions": 1024,
        "max_chunk_count": 100,
        "max_chunk_length": 50,
        "chunking_jobs": 1,
        "max_sentence_count": 100,
        "hnsw_index_params": {"M": 32, "ef_construction": 256, "ef_search": 256},
        "candidates_per_chunk": 20,
        "candidates_per_doc": 100,
        "query_jobs": 1,
        "xgb_shrinkage": 0.03,
        "xgb_interaction_depth": 5,
        "xgb_subsample": 0.7,
        "xgb_rounds": 300,
        "xgb_jobs": 1,
        "duckdb_threads": 1,
        "use_altLabels": True,
        "model_args": {"device": "cpu", "trust_remote_code": False},
        "encode_args_vocab": {"batch_size": 32, "show_progress_bar": True},
        "encode_args_documents": {"batch_size": 32, "show_progress_bar": True},
    }

    # File names, relative to the backend data directory.
    DB_FILE = "ebm-duck.db"
    MODEL_FILE = "ebm-model.gz"
    TRAIN_FILE = "ebm-train.gz"

    # Passed to EbmModel as its chunk tokenizer.
    _analyzer = Analyzer()

    # The EbmModel instance; None until initialize() or _train() sets it.
    _model = None

    def initialize(self, parallel: bool = False) -> None:
        """Load the saved model from the data directory unless already loaded.

        Args:
            parallel: Not used by this backend.

        Raises:
            NotInitializedException: If no model file exists at the expected
                path.
        """
        if self._model is not None:
            return

        path = os.path.join(self.datadir, self.MODEL_FILE)
        self.debug(f"loading model from {path}")
        if not os.path.exists(path):
            raise NotInitializedException(
                f"model not found at {path}", backend_id=self.backend_id
            )

        self._model = EbmModel.load(path)
        # Route the model's log output through this backend.
        self._model.init_logger(logger=self)
        self.debug("loaded model")

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Train the EBM model and save it to the data directory.

        Args:
            corpus: The training documents, or the string ``"cached"`` to
                reuse the vector database and training data written by a
                previous run.
            params: Backend parameters; every key of ``DEFAULT_PARAMETERS``
                is read from it.
            jobs: Number of parallel jobs. When non-zero it also overrides
                ``params["duckdb_threads"]``.

        Raises:
            NotSupportedException: If ``corpus`` contains no documents.
            NotInitializedException: If ``corpus`` is ``"cached"`` but the
                database file or the training data file is missing.
        """
        self.info("starting train")
        use_cached = corpus == "cached"

        # Reject an empty corpus before self._model is replaced, so that a
        # refused training call does not discard a previously loaded model.
        if not use_cached and corpus.is_empty():
            raise NotSupportedException(
                f"training backend {self.backend_id} with no documents"
            )

        self._model = EbmModel(
            db_path=os.path.join(self.datadir, self.DB_FILE),
            embedding_model_name=params["embedding_model_name"],
            embedding_dimensions=params["embedding_dimensions"],
            chunk_tokenizer=self._analyzer,
            max_chunk_count=params["max_chunk_count"],
            max_chunk_length=params["max_chunk_length"],
            chunking_jobs=params["chunking_jobs"],
            max_sentence_count=params["max_sentence_count"],
            hnsw_index_params=params["hnsw_index_params"],
            candidates_per_chunk=params["candidates_per_chunk"],
            candidates_per_doc=params["candidates_per_doc"],
            query_jobs=params["query_jobs"],
            xgb_shrinkage=params["xgb_shrinkage"],
            xgb_interaction_depth=params["xgb_interaction_depth"],
            xgb_subsample=params["xgb_subsample"],
            xgb_rounds=params["xgb_rounds"],
            xgb_jobs=params["xgb_jobs"],
            duckdb_threads=jobs if jobs else params["duckdb_threads"],
            use_altLabels=params["use_altLabels"],
            model_args=params["model_args"],
            encode_args_vocab=params["encode_args_vocab"],
            encode_args_documents=params["encode_args_documents"],
            logger=self,
        )

        train_path = os.path.join(self.datadir, self.TRAIN_FILE)

        if not use_cached:
            self.info("creating vector database")
            self._model.create_vector_db(
                vocab_in_path=os.path.join(
                    self.project.vocab.datadir, self.project.vocab.INDEX_FILENAME_TTL
                ),
                force=True,
            )

            self.info("preparing training data")
            # Three parallel lists with one entry per (document, subject)
            # pair; the document text is repeated for each of its subjects.
            doc_ids = []
            texts = []
            label_ids = []
            for doc_id, doc in enumerate(corpus.documents):
                for subject_id in doc.subject_set:
                    doc_ids.append(doc_id)
                    texts.append(doc.text)
                    label_ids.append(self.project.subjects[subject_id].uri)

            train_data = self._model.prepare_train(
                doc_ids=doc_ids,
                label_ids=label_ids,
                texts=texts,
                n_jobs=jobs,
            )

            # Keep the prepared data on disk so a later "cached" run can
            # skip the preparation step.
            atomic_save(
                obj=train_data,
                dirname=self.datadir,
                filename=self.TRAIN_FILE,
                method=joblib.dump,
            )
        else:
            self.info("reusing cached training data from previous run")
            if not os.path.exists(self._model.db_path):
                raise NotInitializedException(
                    f"database file {self._model.db_path} not found",
                    backend_id=self.backend_id,
                )
            if not os.path.exists(train_path):
                raise NotInitializedException(
                    f"train data file {self.TRAIN_FILE} not found",
                    backend_id=self.backend_id,
                )

            train_data = joblib.load(train_path)

        self.info("training model")
        self._model.train(train_data, jobs)

        self.info("saving model")
        atomic_save(self._model, self.datadir, self.MODEL_FILE)

    def _suggest_batch(
        self, documents: list[Document], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Suggest subjects for a batch of documents.

        Args:
            documents: Documents whose ``text`` is passed to the model.
            params: Backend parameters; ``params["limit"]`` caps the number
                of suggestions per document.

        Returns:
            A ``SuggestionBatch`` built from one score vector per entry
            returned by the model's ``predict`` call.
        """
        # Parse the limit once and use the same value everywhere below.
        limit = int(params["limit"])
        subjects = self.project.subjects

        candidates = self._model.generate_candidates_batch(
            texts=[doc.text for doc in documents],
            doc_ids=list(range(len(documents))),
        )
        predictions = self._model.predict(candidates)

        suggestions = []
        for doc_predictions in predictions:
            vector = np.zeros(len(subjects), dtype=np.float32)
            for row in doc_predictions.iter_rows(named=True):
                position = subjects._uri_idx.get(row["label_id"])
                if position is None:
                    # The label is not in the subject index. Writing its
                    # score to position 0 would credit an unrelated subject,
                    # so the prediction is dropped instead.
                    self.debug(
                        f"ignoring prediction for unknown label {row['label_id']}"
                    )
                    continue
                vector[position] = row["score"]
            suggestions.append(vector_to_suggestions(vector, limit))

        return SuggestionBatch.from_sequence(
            suggestions,
            subjects,
            limit=limit,
        )
|
205
|
|
|
|