Passed — Push experiment-llm-ensemble-backen... ( 212700...3deddd ) by Juho, created 03:29

BaseLLMBackend._call_llm() (grade A)

Complexity:  Conditions 2
Size:        Total Lines 29, Code Lines 25
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      2     (cyclomatic complexity)
eloc    25    (effective lines of code)
nop     6     (number of parameters)
dl      0     (duplicated lines)
loc     29    (total lines)
rs      9.28
c       0
b       0
f       0
"""Backend utilizing a large language model."""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any, Optional

import tiktoken
from openai import AzureOpenAI, BadRequestError, OpenAIError

import annif.eval
import annif.parallel
import annif.util
from annif.exception import NotSupportedException, OperationFailedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }
    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        # The API key is read from the AZURE_OPENAI_KEY environment variable
        self.client = AzureOpenAI(
            azure_endpoint=self.params["endpoint"],
            api_version=self.params["api_version"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
        )
        self._verify_connection()

    def _verify_connection(self) -> None:
        # Issue a minimal test completion so that a bad endpoint or key fails
        # fast at initialization time rather than at suggest time
        try:
            self._call_llm(
                system_prompt="You are a helpful assistant.",
                prompt="This is a test prompt to verify the connection.",
                model=self.params["model"],
                params=self.params,
            )
        except OpenAIError as err:
            raise OperationFailedException(
                f"Failed to connect to endpoint {self.params['endpoint']}: {err}"
            ) from err
        print(f"Successfully connected to endpoint {self.params['endpoint']}")
    def default_params(self):
        # Layer the parameter defaults: generic backend defaults first, then
        # the shared LLM defaults, then any subclass-specific overrides
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS.copy())
        params.update(self.DEFAULT_PARAMETERS)
        return params

    def _truncate_text(self, text, encoding, max_prompt_tokens):
        """Truncate text so it contains at most max_prompt_tokens according to the
        OpenAI tokenizer"""
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:max_prompt_tokens])
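Token-based truncation cuts at token boundaries rather than at an arbitrary character offset, so the result is always a valid prefix for the model's tokenizer. A minimal sketch of the same idea, assuming the tiktoken package recognizes the "gpt-4o" model name:

# --- illustration only, not part of the module ---
import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4o")
tokens = encoding.encode("Annif is an automated subject indexing toolkit.")
# Keep at most 5 tokens and decode back to text; the cut falls on a token
# boundary, so the output is an abruptly ending but well-formed prefix
truncated = encoding.decode(tokens[:5])
print(truncated)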
    def _call_llm(
        self,
        system_prompt: str,
        prompt: str,
        model: str,
        params: dict[str, Any],
        response_format: Optional[dict] = None,
    ) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                response_format=response_format,
            )
        except BadRequestError as err:
            print(err)
            # Degrade gracefully: an empty JSON object means "no scores",
            # which the downstream JSON parsing can cope with
            return "{}"
        return completion.choices[0].message.content

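With response_format={"type": "json_object"} the API is constrained to return syntactically valid JSON; note that OpenAI's JSON mode requires the word "JSON" to appear somewhere in the messages, which the scoring system prompt below satisfies. A hypothetical call might look like this (illustrative model and parameter values only; `backend` is assumed to be an initialized BaseLLMBackend subclass):

# --- illustration only, not part of the module ---
response = backend._call_llm(
    system_prompt="Score each keyword against the text and reply in JSON.",
    prompt="Here are the keywords:\ncats\ndogs\n\nHere is the text:\n...",
    model="gpt-4o-mini",  # assumed Azure deployment name
    params={"temperature": 0.0, "top_p": 1.0, "seed": 0},
    response_format={"type": "json_object"},
)
scores = json.loads(response)  # e.g. {"cats": 0.93, "dogs": 0.12}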
class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines results from multiple projects and scores them
    with an LLM"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 127000,
        "llm_weight": 0.7,
        "labels_language": "en",
        "sources_limit": 10,
    }

    def get_hp_optimizer(self, corpus: DocumentCorpus, metric: str) -> None:
        raise NotSupportedException(
            "Hyperparameter optimization for LLM ensemble backend is not possible."
        )
    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")

        # Get suggestions from the source projects and merge them into a
        # single batch, keeping at most sources_limit suggestions per document
        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        # Blend the source scores and the LLM scores as a weighted average
        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        return SuggestionBatch.from_averaged(batches, weights).filter(
            limit=int(params["limit"])
        )

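Since the two weights sum to 1.0, the blend is a convex combination of the two score sets. For example, with the default llm_weight of 0.7, a subject scored 0.5 by the merged sources and 0.9 by the LLM ends up at 0.3 * 0.5 + 0.7 * 0.9 = 0.78. A sketch of the arithmetic, assuming from_averaged computes a plain weighted mean:

# --- illustration only, not part of the module ---
llm_weight = 0.7
source_score, llm_score = 0.5, 0.9
blended = (1.0 - llm_weight) * source_score + llm_weight * llm_score
print(blended)  # approximately 0.78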
    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        model = params["model"]
        # tiktoken may not recognize the full model/deployment name, so drop
        # the part after the last hyphen to get the base model name
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        max_prompt_tokens = int(params["max_prompt_tokens"])

        system_prompt = """
            You will be given text and a list of keywords to describe it. Your task is
            to score the keywords with a value between 0.0 and 1.0. The score value
            should depend on how well the keyword represents the text: a perfect
            keyword should have score 1.0 and a completely unrelated keyword score
            0.0. You must output JSON with the keywords as field names and their
            scores as field values.
            There must be the same number of entries in the JSON as there are lines in
            the input keyword list; do not skip scoring any keywords.
        """

        labels_batch = self._get_labels_batch(suggestion_batch)

        llm_batch_suggestions = []
        for text, labels in zip(texts, labels_batch):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, encoding, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(
                system_prompt,
                prompt,
                model,
                params,
                response_format={"type": "json_object"},
            )
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(f"Error decoding JSON response from LLM: {response}")
                print(f"Error: {err}")
                # Fall back to zero scores for this document so the batch
                # shape stays consistent with the input texts
                llm_batch_suggestions.append(
                    [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
                )
                continue
            # Map each scored label back to a subject ID; labels the LLM
            # invented (not in the candidate list) get a zero score
            llm_batch_suggestions.append(
                [
                    (
                        SubjectSuggestion(
                            subject_id=self.project.subjects.by_label(
                                llm_label, self.params["labels_language"]
                            ),
                            score=score,
                        )
                        if llm_label in labels
                        else SubjectSuggestion(subject_id=None, score=0.0)
                    )
                    for llm_label, score in llm_result.items()
                ]
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

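For a document whose merged sources suggested the labels "cats" and "dogs", the constructed user prompt and the expected reply look roughly like this (illustrative values only):

# --- illustration only, not part of the module ---
prompt = (
    "Here are the keywords:\n"
    "cats\n"
    "dogs\n\n\n"
    "Here is the text:\n"
    "...truncated document text...\n"
)
# Expected JSON reply, scoring every candidate label:
# {"cats": 0.93, "dogs": 0.12}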
    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        # Turn each document's suggestions into the corresponding subject
        # labels in the configured labels_language
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]
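To try the backend out, a project configuration along these lines should work (a sketch only: the section name, vocabulary, source projects, and deployment name are hypothetical, and the model must match an actual deployment behind the Azure endpoint). The API key is supplied through the AZURE_OPENAI_KEY environment variable read in initialize().

# --- hypothetical projects.cfg entry, for illustration ---
[llm-ensemble-en]
name=LLM ensemble (English)
language=en
backend=llm_ensemble
vocab=yso
sources=tfidf-en:0.5,fasttext-en:0.5
endpoint=https://example-resource.openai.azure.com/
model=gpt-4o-mini
llm_weight=0.7
labels_language=en
limit=100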