Passed
Pull Request — main (#859)
by Juho

annif.backend.llm_ensemble   A

Complexity

Total Complexity 32

Size/Duplication

Total Lines 342
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 235
dl 0
loc 342
rs 9.84
c 0
b 0
f 0
wmc 32

13 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseLLMBackend._truncate_text() 0 5 1
A BaseLLMBackend.default_params() 0 5 1
B BaseLLMBackend.initialize() 0 35 6
A LLMEnsembleBackend._get_labels_batch() 0 9 1
A BaseLLMBackend._call_llm() 0 30 2
A LLMEnsembleBackend.get_hp_optimizer() 0 2 1
A LLMEnsembleBackend._suggest_batch() 0 24 4
A BaseLLMBackend._get_tokenizer() 0 8 2
B LLMEnsembleBackend._llm_suggest_batch() 0 63 5
A BaseLLMBackend._verify_connection() 0 11 2
A LLMEnsembleOptimizer._postprocess() 0 7 1
A LLMEnsembleOptimizer._objective() 0 24 2
A LLMEnsembleOptimizer._prepare() 0 45 4
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
        except KeyError as err:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            ) from err

        if api_base_url is not None:
            self._client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self._client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # The tokenizer is only needed when prompt truncation is enabled
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)
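
    # Hypothetical environment setup for the two supported client types; the
    # variable names are the ones read above, the values are examples only:
    #
    #   export LLM_API_BASE_URL=http://localhost:8000/v1  # any OpenAI-compatible API
    #   export LLM_API_KEY=sk-...                         # optional; defaults to "dummy-key"
    # or
    #   export AZURE_OPENAI_ENDPOINT=https://<resource>.openai.azure.com/
    #   export AZURE_OPENAI_KEY=...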

    def _get_tokenizer(self):
        try:
            # Try an OpenAI tokenizer first
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fall back to a Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so that it contains at most max_prompt_tokens tokens
        according to the configured tokenizer"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])
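
    # For example, with max_prompt_tokens=2 and a BPE encoding such as
    # tiktoken's cl100k_base, "Hello world example" would be truncated to
    # roughly "Hello world" (the exact cut depends on the tokenizer).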

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            print(err)
            # Return an empty JSON object so that callers simply score nothing
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and
    scores them with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )
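
    # Assuming from_averaged applies each exponent to its batch's scores before
    # taking the weighted mean, a subject scored s_src by the merged sources
    # and s_llm by the LLM ends up with roughly
    #   (1 - llm_weight) * s_src + llm_weight * s_llm ** llm_exponent
    # e.g. llm_weight=0.7, llm_exponent=1.0, s_src=0.5, s_llm=0.8 -> 0.71.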

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:
        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with the keywords as field names and their
            scores as field values.
            There must be the same number of entries in the JSON as there are lines
            in the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                # response may be None, so coerce to str before slicing
                print(f"Error decoding JSON response from LLM: '{str(response)[:100]}...'")
                print(str(err))
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

        # Query the LLM for all documents in the batch concurrently
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = annif.util.parse_sources(self._backend.params["sources"])
        project_ids = [source[0] for source in sources]
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            project_ids,
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # Get the LLM batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources,
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.25, 10.0, log=True),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
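
As a usage sketch, a backend like this would be enabled through an entry in Annif's projects.cfg. The snippet below is hypothetical (the project ID, vocabulary, and source projects are invented for illustration), but the parameter names correspond to the DEFAULT_PARAMETERS defined above:

[llm-ensemble-en]
name=LLM ensemble (English)
language=en
backend=llm_ensemble
vocab=yso
sources=tfidf-en,omikuji-parabel-en
model=gpt-4o-mini
limit=100
llm_weight=0.7
llm_exponent=1.0
labels_language=en
sources_limit=10
max_prompt_tokens=0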
342