Passed
Push — experiment-llm-ensemble-backen... (fc5cf8...212700), by Juho, created 11:45

BaseLLMBackend.initialize()   A

Complexity
  Conditions: 1

Size
  Total Lines: 6
  Code Lines: 6

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
cc      1
eloc    6
nop     2
dl      0
loc     6
rs      10
c       0
b       0
f       0
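
The abbreviated metrics presumably expand to cyclomatic complexity (cc, matching Conditions), effective/total lines of code (eloc/loc), number of parameters (nop) and duplicated lines (dl); rs, c, b and f are not spelled out in the report. The analyzer is not named either, but comparable per-function figures can be reproduced locally, for example with radon (the file path below is an assumption):

# Rough reproduction of per-function complexity and size metrics using radon.
# The path and the metric mapping are assumptions; the report's tool is unnamed.
from radon.complexity import cc_rank, cc_visit
from radon.raw import analyze

with open("annif/backend/llm_ensemble.py") as f:  # hypothetical path
    source = f.read()

for block in cc_visit(source):  # one entry per function/method
    print(block.fullname, block.complexity, cc_rank(block.complexity))

raw = analyze(source)  # module-level raw line counts
print(f"loc={raw.loc} sloc={raw.sloc} comments={raw.comments} blank={raw.blank}")

The full source file under analysis follows.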
"""Language model based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

import tiktoken
from openai import AzureOpenAI, BadRequestError

import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble

# from openai import AsyncAzureOpenAI


if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for backends that use a large language model via the Azure
    OpenAI API"""

    DEFAULT_PARAMETERS = {
        "api_version": "2024-10-21",
        "temperature": 0.0,
        "top_p": 1.0,
        "seed": 0,
    }

    def initialize(self, parallel: bool = False) -> None:
        super().initialize(parallel)
        self.client = AzureOpenAI(
            azure_endpoint=self.params["endpoint"],
            api_version=self.params["api_version"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
        )
        # TODO: Verify the connection?
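        # One option (untested assumption, not in the original): a lightweight
        # call such as self.client.models.list() would fail fast on a wrong
        # endpoint or a missing AZURE_OPENAI_KEY.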

    def default_params(self):
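        # Defaults are layered in precedence order: AnnifBackend base defaults,
        # then BaseLLMBackend's, then the concrete subclass's DEFAULT_PARAMETERS;
        # later update() calls override earlier keys.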
        params = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
        params.update(BaseLLMBackend.DEFAULT_PARAMETERS)
        params.update(self.DEFAULT_PARAMETERS)
        return params


class LLMEnsembleBackend(BaseLLMBackend, ensemble.EnsembleBackend):
    """Ensemble backend that combines and re-scores results from multiple
    projects using a large language model"""

    name = "llm_ensemble"

    DEFAULT_PARAMETERS = {
        "max_prompt_tokens": 127000,
        "llm_weight": 0.7,
        "labels_language": "en",
        "sources_limit": 10,
    }

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and a completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
        There must be the same number of fields in the JSON as there are lines in the
        input keyword list; do not skip scoring any keywords.
    """
    # Give zero or very low score to the keywords that do not describe the text.
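    # Illustrative example (not from the source): for the keywords "cats" and
    # "dogs", a valid response would be {"cats": 0.9, "dogs": 0.1}.
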
    def get_hp_optimizer(self, corpus: DocumentCorpus, metric: str) -> None:
        raise NotSupportedException(
            "Hyperparameter optimization for LLM ensemble backend is not possible."
        )

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        llm_weight = float(params["llm_weight"])
        if llm_weight < 0.0 or llm_weight > 1.0:
            raise ValueError("llm_weight must be between 0.0 and 1.0")

        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, {"limit": params["sources_limit"]}
        )

        # Score the suggestion labels with the LLM
        llm_results_batch = self._llm_suggest_batch(texts, merged_source_batch, params)

        # Average the source scores and the LLM scores, weighted by llm_weight
        batches = [merged_source_batch, llm_results_batch]
        weights = [1.0 - llm_weight, llm_weight]
        return SuggestionBatch.from_averaged(batches, weights).filter(
            limit=int(params["limit"])
        )

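    # Worked example (illustrative): with the default llm_weight = 0.7, a subject
    # scored 0.8 by the merged sources and 0.5 by the LLM ends up with
    # 0.3 * 0.8 + 0.7 * 0.5 = 0.59, before the final limit filter is applied.
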
    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:

        model = params["model"]
        # Strip the last dash-separated part of the model name (presumably a
        # version or deployment suffix) so that tiktoken can resolve the base model
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        max_prompt_tokens = int(params["max_prompt_tokens"])

        labels_batch = self._get_labels_batch(suggestion_batch)

        llm_batch_suggestions = []
        for text, labels in zip(texts, labels_batch):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, encoding, max_prompt_tokens)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(prompt, model, params)
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                self.warning(
                    f"error decoding JSON response from LLM: {response} ({err})"
                )
                # Fall back to zero scores for this document
                llm_batch_suggestions.append(
                    [SubjectSuggestion(subject_id=None, score=0.0) for _ in labels]
                )
                continue
            # Map each scored label back to a subject ID; labels that the LLM
            # invented (not in the suggested list) get a null, zero-score suggestion
            llm_batch_suggestions.append(
                [
                    (
                        SubjectSuggestion(
                            subject_id=self.project.subjects.by_label(
                                llm_label, self.params["labels_language"]
                            ),
                            score=score,
                        )
                        if llm_label in labels
                        else SubjectSuggestion(subject_id=None, score=0.0)
                    )
                    for llm_label, score in llm_result.items()
                ]
            )

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        return [
            [
                self.project.subjects[suggestion.subject_id].labels[
                    self.params["labels_language"]
                ]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]

    def _truncate_text(self, text, encoding, max_prompt_tokens):
        """Truncate the text so that it contains at most max_prompt_tokens tokens
        according to the OpenAI tokenizer"""
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:max_prompt_tokens])

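    # Note: only the document text is truncated; the system prompt and the
    # keyword list add tokens on top, so max_prompt_tokens needs some headroom
    # below the model's context window (the default 127000 presumably targets
    # models with a 128k-token context).
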
    def _call_llm(self, prompt: str, model: str, params: dict[str, Any]) -> str:
        temperature = float(params["temperature"])
        top_p = float(params["top_p"])
        seed = int(params["seed"])

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                seed=seed,
                top_p=top_p,
                # JSON mode: the API requires the word "JSON" to appear in the
                # messages, which the system prompt satisfies
                response_format={"type": "json_object"},
            )
            return completion.choices[0].message.content
        except BadRequestError as err:  # TODO: also handle openai.RateLimitError?
            self.warning(str(err))
            return "{}"
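
For reference, a hypothetical projects.cfg entry using this backend could look like the following; the project and vocabulary names, source projects, model and endpoint are illustrative rather than taken from the source, and the API key is read from the AZURE_OPENAI_KEY environment variable:

[llm-ensemble-en]
name=LLM ensemble English
language=en
backend=llm_ensemble
vocab=yso
sources=tfidf-en:1,omikuji-parabel-en:1
model=gpt-4o-mini
endpoint=https://example-resource.openai.azure.com/
llm_weight=0.7
limit=100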