Passed
Pull Request — main (#859)
by Juho
06:23 (queued 03:10)

LLMEnsembleOptimizer._objective()   A

Complexity

Conditions 2

Size

Total Lines 24
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      2
eloc    20
nop     2
dl      0
loc     24
rs      9.4
c       0
b       0
f       0
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
import textwrap
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

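    # Defaults for the chat completions call; temperature 0.0 with a fixed
    # seed keeps responses as reproducible as the API allows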
    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

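        # Pick the client from the environment: a generic OpenAI-compatible
        # endpoint takes precedence over an Azure OpenAI endpoint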
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
        except KeyError:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            )

        if api_base_url is not None:
            self._client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self._client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # Tokenizer is unnecessary if truncation is not performed
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)

    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
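            # rsplit drops the last hyphenated segment ("gpt-4o-mini" -> "gpt-4o")
            # so tiktoken can match a known base model name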
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fallback to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                user_prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        self.debug(f"connection successful to endpoint {self._client.base_url}")

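    # Defaults are layered: generic backend settings first, then the shared
    # LLM settings above, then the concrete subclass's DEFAULT_PARAMETERS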
    def default_params(self) -> dict[str, Any]:
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to the
        tokenizer in use"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            self.warning(
                "error calling LLM API; the base project scores are directly used. "
                f'API error: "{err}"'
            )
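            # Return an empty JSON object so json.loads succeeds with an empty
            # result and the base project scores are used as-is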
            return "{}"
142
        return completion.choices[0].message.content
143
144
145
class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
146
    """Ensemble backend that combines results from multiple projects and scores them
147
    with a LLM"""
148
149
    name = "llm_ensemble"
150
151
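    # max_prompt_tokens=0 disables input truncation; llm_weight sets the share
    # of the LLM scores in the ensemble and llm_exponent rescales them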
    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
        "max_workers": 32,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

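        # Combine the merged source scores and the LLM scores as a weighted
        # average with per-batch exponents, then keep only the top suggestions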
        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:
        max_prompt_tokens = int(params["max_prompt_tokens"])
        max_workers = int(params["max_workers"])

        system_prompt = textwrap.dedent(
            """\
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score 0.

            You must output JSON with keywords as field names and add their scores as
            field values.

            There must be the same number of fields in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """
        )

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            user_prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n\n"
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)

            user_prompt += "Here is the text:\n" + text
            self.debug(f'LLM system prompt: "{system_prompt}"')
            self.debug(f'LLM user prompt: "{user_prompt}"')

            response = self._call_llm(
                system_prompt,
                user_prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except json.JSONDecodeError as err:
                start = max(err.pos - 100, 0)
                end = err.pos + 101  # Slicing out of bounds is ok
                snippet = response[start:end]
                self.warning(
                    f"Failed to decode JSON response from LLM.\n"
                    f"Error: {err}\n"
                    f"Context (around error position {err.pos}):\n"
                    f"...{snippet}..."
                )
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
            except TypeError as err:
                self.warning(
                    f"Failed to decode JSON response from LLM due to TypeError: {err}\n"
                )
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
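            # Map each scored label back to a subject ID; labels the LLM
            # invented (not among the candidates) get a zero-score placeholder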
            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

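        # One LLM call per document, fanned out over a thread pool;
        # executor.map preserves input order, so results align with texts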
        with concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        ) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = annif.util.parse_sources(self._backend.params["sources"])
        project_ids = [source[0] for source in sources]
        self._backend.initialize(parallel=True)

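        # Suggest with each source project in parallel worker processes,
        # collecting the suggestion batches together with the gold standard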
        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            project_ids,
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        self.debug("Generating source batches")
        with pool_class(jobs) as pool:
            # Use ordered imap (not imap_unordered) so the collected batches
            # stay aligned with self._corpus.doc_batches in the LLM step below
            for suggestions_batch, gold_batch in pool.imap(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

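        # Generate the LLM-scored batches once up front; the optimization
        # trials in _objective() only re-weight these cached suggestions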
        self.debug("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources,
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
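        # Search space: llm_weight varies linearly in [0, 1]; llm_exponent is
        # sampled log-uniformly between 0.25 and 10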
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.25, 10.0, log=True),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
366