Passed
Pull Request — main (#859)
by Juho
04:03 queued 28s

annif.backend.llm_ensemble   A

Complexity

Total Complexity 32

Size/Duplication

Total Lines 347
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 235
dl 0
loc 347
rs 9.84
c 0
b 0
f 0
wmc 32

13 Methods

Rating   Name   Duplication   Size   Complexity  
A BaseLLMBackend._truncate_text() 0 5 1
A BaseLLMBackend.default_params() 0 5 1
B BaseLLMBackend.initialize() 0 35 6
A LLMEnsembleBackend._get_labels_batch() 0 9 1
A BaseLLMBackend._call_llm() 0 30 2
A LLMEnsembleBackend.get_hp_optimizer() 0 2 1
A LLMEnsembleBackend._suggest_batch() 0 24 4
A BaseLLMBackend._get_tokenizer() 0 8 2
B LLMEnsembleBackend._llm_suggest_batch() 0 63 5
A BaseLLMBackend._verify_connection() 0 11 2
A LLMEnsembleOptimizer._postprocess() 0 7 1
A LLMEnsembleOptimizer._objective() 0 24 2
A LLMEnsembleOptimizer._prepare() 0 50 4
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
        except KeyError as err:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            ) from err
        if api_base_url is not None:
            self._client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self._client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # Tokenizer is unnecessary if truncation is not performed
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)

    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
            base_model = self.model.rsplit("-", 1)[0]
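            # e.g. a deployment name with a suffix, such as "gpt-4o-mini-dev"
            # (hypothetical example), is reduced to "gpt-4o-mini" for the
            # tiktoken lookup below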
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fallback to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        # print(f"Successfully connected to endpoint {self.params['endpoint']}")

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to the
        tokenizer in use (tiktoken or the Hugging Face fallback)"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])
    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            print(err)
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }
    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")
        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
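        # Assumption about from_averaged: each batch's scores are raised to the
        # corresponding exponent and then combined as a weighted average, roughly
        # (1 - llm_weight) * source_score + llm_weight * llm_score ** llm_exponent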
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])
        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with keywords as field names and add their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """
        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: '{response[:100]}...'")
                print(f"{str(err)}")
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]
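        # executor.map yields results in input order, so the per-document suggestion
        # lists below stay aligned with texts and labels_batch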
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = annif.util.parse_sources(self._backend.params["sources"])
        project_ids = [source[0] for source in sources]

        # initialize the source projects before forking, to save memory
        # for project_id in sources.keys():
        #     project = self._backend.project.registry.get_project(project_id)
        #     project.initialize(parallel=True)
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            project_ids,
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # get the llm batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources,
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.25, 10.0, log=True),
        }
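        # Optuna draws llm_weight uniformly from [0.0, 1.0] and llm_exponent
        # log-uniformly from [0.25, 10.0] (log=True) for each trial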
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
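
For context, a hypothetical projects.cfg entry using this backend could look roughly as follows. The backend name and parameter names come from the code above; the project id, vocabulary, source project ids and model name are placeholder assumptions, and the API endpoint and key are read from the AZURE_OPENAI_ENDPOINT/AZURE_OPENAI_KEY or LLM_API_BASE_URL/LLM_API_KEY environment variables as in initialize():

[llm-ensemble-en]
name=LLM ensemble (English)
language=en
backend=llm_ensemble
vocab=example-vocab
sources=source-project-a:1,source-project-b:1
model=gpt-4o-mini
llm_weight=0.7
llm_exponent=1.0
labels_language=en
limit=100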