Passed
Push — experiment-llm-ensemble-backen... ( 3deddd...e8eaa0 )
by Juho
03:10
created

LLMEnsembleOptimizer._postprocess()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 2
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Backend utilizing a large-language model."""
2
3
from __future__ import annotations
4
5
import json
6
import os
7
from typing import TYPE_CHECKING, Any, Optional
8
9
import tiktoken
10
from openai import AzureOpenAI, BadRequestError, OpenAIError
11
12
import annif.eval
13
import annif.parallel
14
import annif.util
15
from annif.exception import OperationFailedException
16
from annif.suggestion import SubjectSuggestion, SuggestionBatch
17
18
from . import backend, ensemble, hyperopt
19
20
if TYPE_CHECKING:
21
    from annif.corpus.document import DocumentCorpus
22
23
24
class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends.

    Creates an Azure OpenAI client from the ``AZURE_OPENAI_ENDPOINT`` and
    ``AZURE_OPENAI_KEY`` environment variables and provides helpers for
    calling a chat completion model.
    """

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }

    def initialize(self, parallel: bool = False) -> None:
        """Initialize the backend, create the client and verify the connection.

        :param parallel: passed on to the parent class initialization
        :raises OperationFailedException: if the client cannot be created or
            the verification call to the model fails
        """
        super().initialize(parallel)
        try:
            self.client = AzureOpenAI(
                azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
                api_key=os.getenv("AZURE_OPENAI_KEY"),
                api_version=self.params["api_version"],
            )
        except (ValueError, OpenAIError) as err:
            # The client constructor may reject missing configuration with
            # either ValueError or OpenAIError; report both the same way.
            raise OperationFailedException(
                f"Failed to connect to Azure endpoint: {err}"
            ) from err
        self._verify_connection()

    def _verify_connection(self) -> None:
        """Send a minimal test prompt to the configured model.

        :raises OperationFailedException: if the call raises an OpenAIError
        """
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                model=self.params["model"],
                params=self.params,
            )
        except OpenAIError as err:
            # "endpoint" is not a default parameter (the client reads the
            # endpoint from the environment), so look it up without risking a
            # KeyError that would hide the actual connection error.
            endpoint = self.params.get("endpoint") or os.getenv(
                "AZURE_OPENAI_ENDPOINT"
            )
            raise OperationFailedException(
                f"Failed to connect to endpoint {endpoint}: {err}"
            ) from err

    def default_params(self) -> dict[str, Any]:
        """Return the default parameters.

        Generic backend defaults are overridden by the shared LLM defaults,
        which are in turn overridden by the concrete subclass defaults.
        """
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, encoding, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to the
        given tiktoken encoding"""
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:max_prompt_tokens])

    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        model: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        """Send one system + user message pair to the chat completion API.

        :param system_prompt: content of the system message
        :param prompt: content of the user message
        :param model: model (deployment) name passed to the API
        :param params: must contain ``temperature``, ``top_p`` and ``seed``
        :param response_format: optional response format passed to the API
        :returns: the content of the first choice, or ``"{}"`` if the request
            was rejected with BadRequestError (the error is printed)
        """
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                response_format=response_format,
            )
        except BadRequestError as err:
            # Best effort: a rejected request yields an empty JSON object so
            # callers can continue with the remaining documents.
            print(err)
            return "{}"
        return completion.choices[0].message.content
103
104
105
class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with a LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 127000,
        "llm_weight": 0.7,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(
        self, corpus: DocumentCorpus, metric: str
    ) -> LLMEnsembleOptimizer:
        """Return a hyperparameter optimizer for tuning ``llm_weight``."""
        return LLMEnsembleOptimizer(self, corpus, metric)

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Suggest subjects by blending source-project and LLM scores.

        The merged source suggestions (limited to ``sources_limit``) are
        rescored by the LLM, and the two batches are averaged with weights
        ``1 - llm_weight`` and ``llm_weight``.

        :raises ValueError: if ``llm_weight`` is outside [0.0, 1.0]
        """
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        return SuggestionBatch.from_averaged(batches, weights).filter(
            limit=int(params["limit"])
        )

    @staticmethod
    def _get_encoding(model: str) -> tiktoken.Encoding:
        """Return the tiktoken encoding used to truncate texts for ``model``.

        The last ``-``-separated component is stripped first (presumably a
        deployment/version suffix). If tiktoken does not recognise the
        stripped name (e.g. ``gpt-4o`` -> ``gpt``), the full name is tried.
        """
        try:
            return tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        except KeyError:
            return tiktoken.encoding_for_model(model)

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:
        """Ask the LLM to score each document's candidate subjects.

        :param texts: document texts
        :param suggestion_batch: candidate subjects for each document
        :param params: backend parameters; reads ``model``,
            ``max_prompt_tokens`` and ``labels_language`` (and the sampling
            parameters used by ``_call_llm``)
        :returns: a SuggestionBatch with the LLM scores
        """
        model = params["model"]
        encoding = self._get_encoding(model)
        max_prompt_tokens = int(params["max_prompt_tokens"])
        labels_language = params["labels_language"]

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0.0 and 1.0. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 1.0 and completely unrelated keyword score
            0.0. You must output JSON with keywords as field names and add their scores
            as field values.
            There must be the same number of objects in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        llm_batch_suggestions = []
        for text, labels in zip(texts, labels_batch):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            # Only the document text is truncated; the keyword list and the
            # system prompt are not counted against max_prompt_tokens.
            text = self._truncate_text(text, encoding, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                model,
                params,
                response_format={"type": "json_object"},
            )
            llm_batch_suggestions.append(
                self._parse_llm_response(response, labels, labels_language)
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _parse_llm_response(
        self,
        response: Optional[str],
        labels: list[str],
        labels_language: str,
    ) -> list[SubjectSuggestion]:
        """Convert one JSON response from the LLM into subject suggestions.

        Only keywords that were offered to the LLM and resolve to a subject
        are kept. Scores are converted to float and clamped to [0.0, 1.0].
        An undecodable or non-object response yields no suggestions.
        """
        try:
            llm_result = json.loads(response)
        except (TypeError, json.decoder.JSONDecodeError) as err:
            print(f"Error decoding JSON response from LLM: {response}")
            print(f"Error: {err}")
            return []
        if not isinstance(llm_result, dict):
            print(f"Unexpected JSON response from LLM: {response}")
            return []

        offered = set(labels)
        suggestions = []
        for llm_label, raw_score in llm_result.items():
            # Skip keywords the LLM invented or altered.
            if llm_label not in offered:
                continue
            try:
                score = float(raw_score)
            except (TypeError, ValueError):
                continue
            subject_id = self.project.subjects.by_label(llm_label, labels_language)
            if subject_id is None:
                continue
            suggestions.append(
                SubjectSuggestion(
                    subject_id=subject_id, score=min(max(score, 0.0), 1.0)
                )
            )
        return suggestions

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        """Return the labels (in ``labels_language``) of each document's
        suggested subjects, preserving the batch structure."""
        language = self.params["labels_language"]
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[language]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]
220
221
222
class LLMEnsembleOptimizer(ensemble.EnsembleOptimizer):
    """Hyperparameter optimizer for the LLM ensemble backend.

    Source-project and LLM suggestions are computed once in ``_prepare``;
    each trial in ``_objective`` only re-weights the cached batches.
    """

    def _prepare(self, n_jobs=1):
        """Compute and cache source, merged-source, LLM and gold batches.

        :param n_jobs: number of parallel jobs for the source suggestions
        """
        sources = dict(annif.util.parse_sources(self._backend.params["sources"]))
        source_items = list(sources.items())
        merge_params = {"limit": self._backend.params["sources_limit"]}

        # Initialize before forking, to save memory (presumably this also
        # initializes the source projects via the ensemble base class).
        self._backend.initialize(parallel=True)

        psmap = annif.parallel.ProjectSuggestMap(
            self._backend.project.registry,
            list(sources.keys()),
            backend_params=None,
            limit=None,
            threshold=0.0,
        )

        jobs, pool_class = annif.parallel.get_pool(n_jobs)

        # Materialize the document batches once: the texts are needed again
        # for the LLM step and must line up with the source and gold batches.
        doc_batches = list(self._corpus.doc_batches)

        self._gold_batches = []
        self._source_batches = []

        print("Generating source batches")
        with pool_class(jobs) as pool:
            # imap (not imap_unordered) keeps results in corpus order, so the
            # i-th source batch belongs to the i-th document batch.
            for suggestions_batch, gold_batch in pool.imap(
                psmap.suggest_batch, doc_batches
            ):
                self._source_batches.append(suggestions_batch)
                self._gold_batches.append(gold_batch)

        print("Generating LLM batches")
        self._merged_source_batches = []
        self._llm_batches = []
        for batch_by_source, docs_batch in zip(self._source_batches, doc_batches):
            merged_source_batch = self._backend._merge_source_batches(
                batch_by_source,
                source_items,
                merge_params,
            )
            llm_batch = self._backend._llm_suggest_batch(
                [doc.text for doc in docs_batch],
                merged_source_batch,
                self._backend.params,
            )
            self._merged_source_batches.append(merged_source_batch)
            self._llm_batches.append(llm_batch)

    def _objective(self, trial) -> float:
        """Evaluate one trial's ``llm_weight`` on the cached batches.

        :returns: the value of the configured metric
        """
        eval_batch = annif.eval.EvaluationBatch(self._backend.project.subjects)
        llm_weight = trial.suggest_float("llm_weight", 0.0, 1.0)
        weights = [1.0 - llm_weight, llm_weight]
        limit = int(self._backend.params["limit"])
        print("Evaluating...")
        for merged_source_batch, llm_batch, gold_batch in zip(
            self._merged_source_batches, self._llm_batches, self._gold_batches
        ):
            avg_batch = SuggestionBatch.from_averaged(
                [merged_source_batch, llm_batch], weights
            ).filter(limit=limit)
            eval_batch.evaluate_many(avg_batch, gold_batch)
        results = eval_batch.results(metrics=[self._metric])
        return results[self._metric]

    def _postprocess(self, study):
        """Turn the best trial into a configuration recommendation."""
        bp = study.best_params
        lines = [
            f"llm_weight={bp['llm_weight']}",
        ]
        return hyperopt.HPRecommendation(lines=lines, score=study.best_value)
299