Passed
Push — experiment-llm-ensemble-backen... (340df7...bee7b8)
created by Juho, 06:46

BaseLLMBackend.initialize()   B

Complexity

Conditions 6

Size

Total Lines 35
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 25
nop 2
dl 0
loc 35
rs 8.3466
c 0
b 0
f 0
"""Backend utilizing a large-language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAI, OpenAIError
from transformers import AutoTokenizer

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    _client = None

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
        "max_completion_tokens": 2000,
    }

    def initialize(self, parallel: bool = False) -> None:
        if self._client is not None:
            return

        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        try:
            self.model = self.params["model"]
47
        except KeyError as err:
48
            raise ConfigurationException(
49
                "model setting is missing", project_id=self.project.project_id
50
            )
51
52
        if api_base_url is not None:
53
            self._client = OpenAI(
54
                base_url=api_base_url,
55
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
56
            )
57
        elif azure_endpoint is not None:
58
            self._client = AzureOpenAI(
59
                azure_endpoint=azure_endpoint,
60
                api_key=os.getenv("AZURE_OPENAI_KEY"),
61
                api_version=self.params["api_version"],
62
            )
63
        else:
64
            raise OperationFailedException(
65
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
66
                "environment variable for LLM API access."
67
            )
68
69
        # Tokenizer is unnecessary if truncation is not performed
70
        if int(self.params["max_prompt_tokens"]) > 0:
71
            self.tokenizer = self._get_tokenizer()
72
        self._verify_connection()
73
        super().initialize(parallel)
74
75
    def _get_tokenizer(self):
76
        try:
77
            # Try OpenAI tokenizer
78
            base_model = self.model.rsplit("-", 1)[0]
79
            return tiktoken.encoding_for_model(base_model)
80
        except KeyError:
81
            # Fallback to Hugging Face tokenizer
82
            return AutoTokenizer.from_pretrained(self.model)
83
84
    def _verify_connection(self):
85
        try:
86
            self._call_llm(
87
                system_prompt="You are a helpful assistant.",
88
                prompt="This is a test prompt to verify the connection.",
89
                params=self.params,
90
            )
91
        except OpenAIError as err:
92
            raise OperationFailedException(
93
                f"Failed to connect to LLM API: {err}"
94
            ) from err
95
        # print(f"Successfully connected to endpoint {self.params['endpoint']}")
96
97
    def default_params(self):
98
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
99
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
100
        params.update(self.DEFAULT_PARAMETERS)
101
        return params
102
    def _truncate_text(self, text, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to
        the tokenizer in use"""
        tokens = self.tokenizer.encode(text)
        return self.tokenizer.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])
        max_completion_tokens = int(params["max_completion_tokens"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self._client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                max_completion_tokens=max_completion_tokens,
                response_format=response_format,
            )
        except OpenAIError as err:
            print(err)
            # On API errors fall back to an empty JSON object so that scoring
            # of the remaining documents can continue
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 0,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with keywords as field names and add their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            if max_prompt_tokens > 0:
                text = self._truncate_text(text, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: '{response[:100]}...'")
                print(f"{str(err)}")
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            # Map the scored labels back to subjects; labels that were not among
            # the candidates get a null suggestion
            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

        # Score the documents concurrently; each document is a separate LLM call
        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = dict(annif.util.parse_sources(self._backend.params["sources"]))

        # initialize the source projects before forking, to save memory
        # for project_id in sources.keys():
        #     project = self._backend.project.registry.get_project(project_id)
        #     project.initialize(parallel=True)
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # get the llm batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources.items(),
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.01, 30.0, log=True),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
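
For reference, a minimal sketch of the JSON contract that the system prompt in _llm_suggest_batch asks the model to follow: the candidate keyword labels become field names and the field values are integer scores from 0 to 100, which the backend then rescales to the 0.0-1.0 range. The response string below is made up for illustration.

import json

# Hypothetical LLM response following the requested JSON format (illustration only)
response = '{"machine learning": 85, "gardening": 5, "neural networks": 90}'

llm_result = json.loads(response)
# Rescale the 0-100 scores to 0.0-1.0, as _llm_suggest_batch does
scores = {label: value / 100.0 for label, value in llm_result.items()}
print(scores)  # {'machine learning': 0.85, 'gardening': 0.05, 'neural networks': 0.9}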
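
Similarly, a standalone sketch of the token-based truncation that _truncate_text performs. The encoding is hard-coded here only for illustration; in the backend, _get_tokenizer resolves it from the configured model name via tiktoken.encoding_for_model and falls back to a Hugging Face tokenizer if the model is unknown to tiktoken.

import tiktoken

# Example encoding; the backend derives the encoding from the model name instead
encoding = tiktoken.get_encoding("cl100k_base")

text = "A very long document text. " * 200
tokens = encoding.encode(text)
truncated = encoding.decode(tokens[:50])  # keep at most 50 tokens
print(len(encoding.encode(truncated)))    # about 50; re-encoding may differ slightly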