Passed
Pull Request — main (#859)
by Juho
06:43 queued 03:38

BaseLLMBackend._verify_connection()    Grade: A

Complexity
    Conditions     2

Size
    Total Lines    12
    Code Lines     11

Duplication
    Lines          0
    Ratio          0 %

Importance
    Changes        0

Metric    Value
cc        2
eloc      11
nop       1
dl        0
loc       12
rs        9.85
c         0
b         0
f         0
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
import textwrap
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }
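    # Decoding defaults favour reproducibility: temperature 0.0 and a fixed
    # seed make completions as deterministic as the API allows.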

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
        except KeyError:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            )

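        # Pick the endpoint: a generic OpenAI-compatible API takes precedence
        # over Azure OpenAI when both environment variables are set.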
        if api_base_url is not None:
            self._client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self._client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # Tokenizer is unnecessary if truncation is not performed
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)

    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
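            # (strip a trailing dash-delimited suffix, e.g. a version tag,
            # from the model name before the tiktoken lookup)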
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fallback to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                user_prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        self.debug(f"connection successful to endpoint {self._client.base_url}")

    def default_params(self):
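        # Defaults are layered: generic backend defaults first, then the
        # shared LLM defaults, then subclass-specific overrides.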
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to
        the tokenizer in use"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        user_prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            self.warning(
                "error calling LLM API; the base project scores are directly used. "
                f'API error: "{err}"'
            )
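            # Returning an empty JSON object yields no LLM suggestions, so
            # only the base project scores take effect.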
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores
    them with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }
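    # llm_weight sets the LLM's share in the final score average;
    # max_prompt_tokens 0 means input texts are not truncated.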

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
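        # The final score is a weighted average of the merged source batch and
        # the LLM batch; the exponents let from_averaged reshape each batch's
        # score distribution (1.0 leaves scores unchanged).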
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = textwrap.dedent(
            """\
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score 0.

            You must output JSON with the keywords as field names and their scores as
            field values.

            There must be the same number of fields in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """
        )
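        # Illustrative example of the JSON the prompt asks for:
        #   {"cats": 85, "dogs": 10, "animal behaviour": 70}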

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            user_prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n\n"
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)

            user_prompt += "Here is the text:\n" + text
            self.debug(f'LLM system prompt: "{system_prompt}"')
            self.debug(f'LLM user prompt: "{user_prompt}"')

            response = self._call_llm(
                system_prompt,
                user_prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except json.JSONDecodeError as err:
                start = max(err.pos - 100, 0)
                end = err.pos + 101  # Slicing out of bounds is ok
                snippet = response[start:end]
                self.warning(
                    f"Failed to decode JSON response from LLM.\n"
                    f"Error: {err}\n"
                    f"Context (around error position {err.pos}):\n"
                    f"...{snippet}..."
                )
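                # Fall back to zero-score placeholders, one per label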
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
            except TypeError as err:
                self.warning(
                    f"Failed to decode JSON response from LLM due to TypeError: {err}\n"
                )
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

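        # The API calls are I/O-bound, so a thread pool runs the per-document
        # prompts concurrently; executor.map keeps results in input order.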
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = annif.util.parse_sources(self._backend.params["sources"])
        project_ids = [source[0] for source in sources]
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            project_ids,
            backend_params=None,
            limit=None,
            threshold=0.0,
        )
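        # ProjectSuggestMap runs the source projects over the corpus, yielding
        # (suggestions, gold subjects) pairs for each document batch.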

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        self.debug("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # Get the LLM batches
        self.debug("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources,
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.25, 10.0, log=True),
        }
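        # llm_exponent is sampled on a log scale since useful values span
        # orders of magnitude (0.25 to 10).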
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
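
For intuition, here is a minimal sketch of the per-subject score combination that SuggestionBatch.from_averaged is asked to perform above, assuming each exponent is applied to its batch's scores before the weighted mean is taken (the combine function is illustrative, not part of the Annif API):

def combine(base_score: float, llm_score: float,
            llm_weight: float = 0.7, llm_exponent: float = 1.0) -> float:
    # Raise each score to its batch's exponent, then take the weighted mean
    weights = [1.0 - llm_weight, llm_weight]
    scores = [base_score ** 1.0, llm_score ** llm_exponent]
    return sum(w * s for w, s in zip(weights, scores)) / sum(weights)

print(combine(0.8, 0.6))  # 0.66 with the defaults llm_weight=0.7, llm_exponent=1.0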
364