Passed
Push llms4subjects-natlibfi-germeva... ( 5ca761...e71b9c )
by Juho, created 03:31

BaseLLMBackend.initialize()   B

Complexity
  Conditions: 5

Size
  Total Lines: 33
  Code Lines: 23

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      5
eloc    23
nop     2
dl      0
loc     33
rs      8.8613
c       0
b       0
f       0
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }

    def initialize(self, parallel: bool = False) -> None:
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")

        try:
            self.model = self.params["model"]
        except KeyError as err:
            raise ConfigurationException(
                "model setting is missing", project_id=self.project.project_id
            ) from err

        if api_base_url is not None:
            self.client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self.client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )

        # Tokenizer is unnecessary if truncation is not performed
        if int(self.params["max_prompt_tokens"]) > 0:
            self.tokenizer = self._get_tokenizer()
        self._verify_connection()
        super().initialize(parallel)

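    # Note: initialize() above configures the API client from environment
    # variables: LLM_API_BASE_URL (plus optional LLM_API_KEY, defaulting to
    # "dummy-key") selects a generic OpenAI-compatible endpoint, while
    # AZURE_OPENAI_ENDPOINT together with AZURE_OPENAI_KEY selects Azure
    # OpenAI; if neither endpoint variable is set, an OperationFailedException
    # is raised.
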
    def _get_tokenizer(self):
        try:
            # Try OpenAI tokenizer
            base_model = self.model.rsplit("-", 1)[0]
            return tiktoken.encoding_for_model(base_model)
        except KeyError:
            # Fall back to Hugging Face tokenizer
            return AutoTokenizer.from_pretrained(self.model)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        # print(f"Successfully connected to endpoint {self.params['endpoint']}")

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so that it contains at most max_prompt_tokens tokens
        according to the tokenizer"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                response_format=response_format,
            )
        except BadRequestError as err:
            print(err)
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

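    # Note on the combination step in _suggest_batch above: the merged source
    # batch and the LLM batch are passed to SuggestionBatch.from_averaged with
    # weights [1.0 - llm_weight, llm_weight] and exponents [1.0, llm_exponent].
    # The exact averaging semantics live in from_averaged (not shown here);
    # presumably llm_weight controls how much the LLM scores contribute and
    # llm_exponent reshapes the LLM score distribution before averaging.
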
    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with keywords as field names and add their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: {response}")
                print(f"Error: {err}")
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = dict(annif.util.parse_sources(self._backend.params["sources"]))

        # initialize the source projects before forking, to save memory
        # for project_id in sources.keys():
        #     project = self._backend.project.registry.get_project(project_id)
        #     project.initialize(parallel=True)
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # get the llm batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources.items(),
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.01, 30.0, log=True),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
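
For reference, a minimal standalone sketch of the response shape that the system prompt in _llm_suggest_batch asks the LLM to return, and of the score normalization applied in process_single_prompt (scores divided by 100). The example response string is invented for illustration only.

    import json

    # Invented example of the JSON the LLM is prompted to return: keyword
    # labels as field names, scores between 0 and 100 as field values.
    example_response = '{"machine learning": 85, "art history": 5}'

    llm_result = json.loads(example_response)
    # process_single_prompt divides each score by 100 to map it into [0.0, 1.0];
    # labels not found in the candidate list fall back to a null suggestion.
    normalized = {label: score / 100.0 for label, score in llm_result.items()}
    print(normalized)  # {'machine learning': 0.85, 'art history': 0.05}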