Passed: Push llms4subjects-natlibfi-germeva... (4c0010...10211c) by Juho, created 03:53

LLMEnsembleOptimizer._objective(): A

Complexity: 2 conditions
Size: 24 total lines, 20 code lines
Duplication: 0 lines (0 %)
Importance: 0 changes

Metric  Value
cc      2
eloc    20
nop     2
dl      0
loc     24
rs      9.4
c       0
b       0
f       0
"""Backend utilizing a large language model."""

from __future__ import annotations

import concurrent.futures
import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAI, OpenAIError

import annif.eval
import annif.parallel
import annif.util
from annif.exception import ConfigurationException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble, hyperopt

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }

    def initialize(self, parallel: bool = False) -> None:
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        api_base_url = os.getenv("LLM_API_BASE_URL")
        if api_base_url is not None:
            self.client = OpenAI(
                base_url=api_base_url,
                api_key=os.getenv("LLM_API_KEY", "dummy-key"),
            )
        elif azure_endpoint is not None:
            self.client = AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        else:
            raise OperationFailedException(
                "Please set the AZURE_OPENAI_ENDPOINT or LLM_API_BASE_URL "
                "environment variable for LLM API access."
            )
        self._verify_connection()
        super().initialize(parallel)

    def _verify_connection(self):
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                model=self.params["model"],
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to LLM API: {err}"
            ) from err
        except KeyError as err:
            if err.args[0] == "model":
                raise ConfigurationException(
                    "model setting is missing", project_id=self.project.project_id
                )
        # print(f"Successfully connected to endpoint {self.params['endpoint']}")

    def default_params(self):
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, encoding, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to the
        OpenAI tokenizer"""
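        # Example: with max_prompt_tokens=3, "one two three four five" would
        # typically be cut back to "one two three" (the exact boundary depends
        # on the tiktoken encoding in use).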
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        model: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                response_format=response_format,
            )
        except BadRequestError as err:
            print(err)
            return "{}"
        return completion.choices[0].message.content


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 127000,
        "llm_weight": 0.7,
        "llm_exponent": 1.0,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        llm_exponent = float(params["llm_exponent"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")
        if llm_exponent < 0.0:
            raise ValueError("llm_exponent must be greater than or equal to 0.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        exponents = [1.0, llm_exponent]
        return SuggestionBatch.from_averaged(batches, weights, exponents).filter(
            limit=int(params["limit"])
        )

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        model = params["model"]
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0 and 100. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 100 and a completely unrelated keyword score
            0. You must output JSON with keywords as field names and add their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        def process_single_prompt(text, labels):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, encoding, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                model,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: {response}")
                print(f"Error: {err}")
                return [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]

            return [
                (
                    SubjectSuggestion(
                        subject_id=self.project.subjects.by_label(
                            llm_label, self.params["labels_language"]
                        ),
                        score=score / 100.0,  # LLM scores are between 0 and 100
                    )
                    if llm_label in labels
                    else SubjectSuggestion(subject_id=None, score=0.0)
                )
                for llm_label, score in llm_result.items()
            ]

        with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
            llm_batch_suggestions = list(
                executor.map(process_single_prompt, texts, labels_batch)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]


class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend"""

    def _prepare(self, n_jobs=1):
        sources = dict(annif.util.parse_sources(self._backend.params["sources"]))

        # initialize the source projects before forking, to save memory
        # for project_id in sources.keys():
        #     project = self._backend.project.registry.get_project(project_id)
        #     project.initialize(parallel=True)
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            for suggestions_batch, gold_batch in pool.imap_unordered(
                psmap.suggest_batch, self._corpus.doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        # get the llm batches
        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(
            self._source_batches, self._corpus.doc_batches
        ):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                sources.items(),
                {"limit": self._backend.params["sources_limit"]},
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        params = {
            "llm_weight": trial.suggest_float("llm_weight", 0.0, 1.0),
            "llm_exponent": trial.suggest_float("llm_exponent", 0.0, 30.0),
        }
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            batches = [merged_source_batch, llm_batch]
            weights = [
                1.0 - params["llm_weight"],
                params["llm_weight"],
            ]
            exponents = [
                1.0,
                params["llm_exponent"],
            ]
            avg_batch = SuggestionBatch.from_averaged(
                batches, weights, exponents
            ).filter(limit=int(self._backend.params["limit"]))
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
            f"llm_exponent={bp['llm_exponent']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
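
For reference, a minimal sketch of the score combination that _suggest_batch and _objective delegate to SuggestionBatch.from_averaged. It assumes each batch's scores are raised to the corresponding exponent and then averaged with the given weights; this is a plausible reading of the call sites above, not the actual Annif implementation, and the numbers are made up.

# Hypothetical, self-contained example (not Annif code): combine one subject's
# score from the merged source projects with the LLM's score for that subject,
# using the backend's llm_weight and llm_exponent parameters.
def combine(source_score, llm_score, llm_weight, llm_exponent):
    weights = [1.0 - llm_weight, llm_weight]  # as in _suggest_batch / _objective
    exponents = [1.0, llm_exponent]
    scores = [source_score, llm_score]
    # weighted average of exponent-scaled scores (assumed semantics of
    # SuggestionBatch.from_averaged; the weights sum to 1.0 at the call sites)
    total = sum(w * s**e for w, s, e in zip(weights, scores, exponents))
    return total / sum(weights)

print(round(combine(0.6, 0.8, llm_weight=0.7, llm_exponent=1.0), 3))  # 0.74
print(round(combine(0.6, 0.8, llm_weight=0.7, llm_exponent=2.0), 3))  # 0.628

The second call illustrates how an exponent above 1.0 pulls sub-1.0 LLM scores down relative to the plain weighted average, which is what the optimizer explores over the 0.0 to 30.0 range of llm_exponent.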