Passed
Pull Request — main (#859)
by Juho
05:43 queued 03:00

BaseLLMBackend._verify_connection()   A

Complexity:  Conditions 2
Size:        Total Lines 12, Code Lines 11
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      2
eloc    11
nop     1
dl      0
loc     12
rs      9.85
c       0
b       0
f       0
"""Backend utilizing a large language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
        except KeyError:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            )

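        # An explicit OpenAI-compatible base URL takes precedence over the
        # Azure endpoint when both are configured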
        if api_base_url is not None:
            self._client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self._client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # Tokenizer is unnecessary if truncation is not performed
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)

    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fall back to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        self.debug(f"connection successful to endpoint {self._client.base_url}")

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so that it contains at most max_prompt_tokens according
        to the backend tokenizer"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            self.warning(
                "error calling LLM API; the base project scores are used directly. "
                f'API error: "{err}"'
            )
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores
    them with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

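        # Combine the merged source scores with the LLM scores as a weighted
        # average; the exponents are passed through to SuggestionBatch.from_averaged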
        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with the keywords as field names and their scores
            as field values.
            There must be the same number of fields in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except json.JSONDecodeError as err:
                start = max(err.pos - 100, 0)
                end = err.pos + 101  # Slicing out of bounds is ok
                snippet = response[start:end]
                self.warning(
                    f"Failed to decode JSON response from LLM.\n"
                    f"Error: {err}\n"
                    f"Context (around error position {err.pos}):\n"
                    f"...{snippet}..."
                )
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
            except TypeError as err:
                self.warning(
                    f"Failed to decode JSON response from LLM due to TypeError: {err}\n"
                )
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
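            # Map each scored label back to a subject; labels the LLM invented
            # (not in the input list) get a zero-score placeholder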
            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

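        # The per-document LLM calls are I/O-bound, so run them concurrently;
        # max_workers caps the number of simultaneous API requests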
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = annif.util.parse_sources(self._backend.params["sources"])
        project_ids = [source[0] for source in sources]
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            project_ids,
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        self.debug("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # Get the LLM batches
        self.debug("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources,
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
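        # Sample the tunable hyperparameters; llm_exponent is drawn on a log
        # scale so that values below and above 1.0 are explored evenly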
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.25, 10.0, log=True),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
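
For reference, a minimal sketch of how this backend could be configured in an Annif projects.cfg entry. The project ID, vocabulary, source projects, and model name below are hypothetical placeholders; the remaining parameters are the ones defined in this file:

    [llm-ensemble-en]
    name=LLM ensemble (English)
    language=en
    backend=llm_ensemble
    vocab=yso
    sources=tfidf-en:1,omikuji-en:1
    model=gpt-4o-mini
    limit=100
    llm_weight=0.7
    llm_exponent=1.0
    labels_language=en
    max_prompt_tokens=4000

At initialization the client credentials are read from the environment: set LLM_API_BASE_URL (and optionally LLM_API_KEY) for an OpenAI-compatible server, or AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY for Azure OpenAI.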