Passed — Push llms4subjects-natlibfi-germeva... (10211c...5ca761), created by Juho at 03:28

annif.backend.llm_ensemble — rating A

Complexity: Total Complexity 29
Size/Duplication: Total Lines 336, Duplicated Lines 0 %
Importance: Changes 0

Metric  Value
eloc    226
dl      0
loc     336
rs      10
c       0
b       0
f       0
wmc     29

13 Methods

Rating  Name                                      Duplication  Size  Complexity
A       LLMEnsembleOptimizer._postprocess()       0            7     1
A       BaseLLMBackend._truncate_text()           0            5     1
A       BaseLLMBackend.default_params()           0            5     1
A       LLMEnsembleOptimizer._objective()         0            24    2
A       BaseLLMBackend.initialize()               0            31    4
A       LLMEnsembleBackend._get_labels_batch()    0            9     1
A       LLMEnsembleOptimizer._prepare()           0            49    4
A       BaseLLMBackend._call_llm()                0            28    2
A       LLMEnsembleBackend.get_hp_optimizer()     0            2     1
A       LLMEnsembleBackend._suggest_batch()       0            24    4
A       BaseLLMBackend._get_tokenizer()           0            8     2
B       LLMEnsembleBackend._llm_suggest_batch()   0            62    4
A       BaseLLMBackend._verify_connection()       0            11    2

"""Backend utilizing a large language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }

    def initialize(self, parallel: bool = False) -> None:
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")

        try:
            self.model = self.params["model"]
        except KeyError as err:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            ) from err

        if api_base_url is not None:
            self.client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self.client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)
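
    # Example environment setup for the two supported access modes; the values
    # below are placeholders, not part of this module:
    #
    #   export LLM_API_BASE_URL=http://localhost:8000/v1   # OpenAI-compatible API
    #   export LLM_API_KEY=...
    #
    #   export AZURE_OPENAI_ENDPOINT=https://<resource>.openai.azure.com/
    #   export AZURE_OPENAI_KEY=...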

    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fallback to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        # print(f"Successfully connected to endpoint {self.params['endpoint']}")

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to
        the tokenizer selected in _get_tokenizer"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

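    # Illustration (not part of the original module): on the tiktoken path,
    # truncation round-trips through the token ids, e.g.
    #   enc = tiktoken.encoding_for_model("gpt-4o")   # model name is a placeholder
    #   enc.decode(enc.encode(text)[:max_prompt_tokens])
    # Hugging Face tokenizers expose the same encode/decode interface.
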
    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                response_format=response_format,
            )
        except BadRequestError as err:
            print(err)
            # return an empty JSON object so downstream parsing can continue
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores
    them with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 127000,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

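    # Hypothetical projects.cfg entry for this backend; the section name, vocab,
    # source project IDs and model are made-up placeholders:
    #
    #   [llm-ensemble-en]
    #   name=LLM ensemble
    #   language=en
    #   backend=llm_ensemble
    #   vocab=yso
    #   sources=tfidf-en,omikuji-parabel-en
    #   model=gpt-4o-mini
    #   limit=100
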
    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

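    # Numeric illustration, assuming from_averaged takes a weighted mean of the
    # scores after raising each batch's scores to its exponent: with
    # llm_weight=0.7 and llm_exponent=1.0, a subject scored 0.5 by the merged
    # sources and 0.9 by the LLM ends up at 0.3 * 0.5 + 0.7 * 0.9 = 0.78.
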
    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with the keywords as field names and their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: {response}")
                print(f"Error: {err}")
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

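    # For illustration: given the keywords ["cats", "dogs"], the model is
    # expected to return e.g. {"cats": 87, "dogs": 12}, which
    # process_single_prompt maps back to SubjectSuggestions scored 0.87 and 0.12.
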
    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = dict(annif.util.parse_sources(self._backend.params["sources"]))

        # initialize the source projects before forking, to save memory
        # for project_id in sources.keys():
        #     project = self._backend.project.registry.get_project(project_id)
        #     project.initialize(parallel=True)
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # get the LLM batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources.items(),
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.0, 30.0),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

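    # The trial argument above is an Optuna trial; the surrounding study is
    # created by Annif's hyperopt machinery. A rough sketch of the driving loop,
    # assuming a metric that is maximized:
    #   study = optuna.create_study(direction="maximize")
    #   study.optimize(self._objective, n_trials=n_trials)
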
    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
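
For context, this backend might be exercised through the Annif CLI; a rough sketch with a made-up project ID and paths (exact commands and flags may differ in this Annif version):

    annif hyperopt llm-ensemble-en /path/to/validation-corpus --trials 50
    annif suggest llm-ensemble-en < document.txt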