Passed
Push: llm-ensemble-backend (5ea851)
created by Juho at 08:01

annif.backend.llm.LLMEnsembleBackend._call_llm() (grade A)

Complexity: Conditions 2
Size: Total Lines 25, Code Lines 21
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric  Value
cc      2
eloc    21
nop     3
dl      0
loc     25
rs      9.376
c       0
b       0
f       0
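
The metric keys above are not tied to a named tool in this report; assuming they are radon-style measurements (cc = cyclomatic complexity, loc = lines of code), a figure like "cc 2" for _call_llm() could be reproduced locally with a sketch such as:

    # Hedged sketch: assumes the report's "cc" is radon-style cyclomatic
    # complexity; the file path is this commit's backend module.
    from radon.complexity import cc_visit

    with open("annif/backend/llm.py") as f:
        source = f.read()

    for block in cc_visit(source):  # one entry per function/method
        print(block.name, block.complexity)  # expect "_call_llm 2" among them
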
"""Language model based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

import tiktoken
from openai import AzureOpenAI, BadRequestError

import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM-based backends"""

    def initialize(self, parallel: bool = False) -> None:
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

        # the API key is read from the AZURE_OPENAI_KEY environment variable
        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )


class LLMEnsembleBackend(BaseLLMBackend, ensemble.BaseEnsembleBackend):
    """Ensemble backend that combines results from multiple projects and
    rescores the combined suggestions with a large language model"""

    name = "llm_ensemble"

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and a completely unrelated keyword score
        0.0. You must output JSON with the keywords as field names and their scores
        as field values.
        There must be the same number of objects in the JSON as there are lines in the
        input keyword list; do not skip scoring any keywords.
    """

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training LLM ensemble backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        batch_by_source = self._suggest_with_sources(texts, sources)
        return self._merge_source_batches(texts, batch_by_source, sources, params)

    def _merge_source_batches(
        self,
        texts: list[str],
        batch_by_source: dict[str, SuggestionBatch],
        sources: list[tuple[str, float]],
        params: dict[str, Any],
    ) -> SuggestionBatch:
        model = params["model"]
        # tiktoken does not know the Azure deployment name as such, so strip
        # the trailing version suffix to get the base model name
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])

        batches = [batch_by_source[project_id] for project_id, _ in sources]
        weights = [weight for _, weight in sources]
        avg_suggestion_batch = SuggestionBatch.from_averaged(batches, weights).filter(
            limit=int(params["limit"])  # TODO: increase limit
        )

        # collect the labels of the averaged suggestions for each text
        labels_batch = []
        for suggestionresult in avg_suggestion_batch:
            labels_batch.append(
                [
                    self.project.subjects[s.subject_id].labels[
                        "en"
                    ]  # TODO: make language selectable
                    for s in suggestionresult
                ]
            )

        llm_batch_suggestions = []
        for text, labels in zip(texts, labels_batch):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, encoding)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(prompt, model)
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                self.warning(f"Could not parse LLM response as JSON: {err}")
                llm_batch_suggestions.append([])  # keep the batch aligned with texts
                continue
            suggestions = []
            for label, score in llm_result.items():
                subj_id = self.project.subjects.by_label(
                    label, "en"
                )  # TODO: make language selectable
                if subj_id is None:
                    # the LLM may return a label that is not in the vocabulary
                    continue
                suggestions.append(SubjectSuggestion(subject_id=subj_id, score=score))

            llm_batch_suggestions.append(suggestions)

        return SuggestionBatch.from_sequence(
            llm_batch_suggestions, self.project.subjects
        )

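    # Note on the merge step above (illustrative arithmetic, not from the
    # commit): SuggestionBatch.from_averaged takes a weighted mean of the
    # source scores, so with sources "projA:2,projB:1" a subject scored 0.6
    # by projA and 0.3 by projB enters the LLM rescoring step with score
    # (2 * 0.6 + 1 * 0.3) / 3 = 0.5.
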
    def _truncate_text(self, text, encoding):
        """Truncate the text so that it contains at most MAX_PROMPT_TOKENS
        according to the OpenAI tokenizer"""

        MAX_PROMPT_TOKENS = 14000
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:MAX_PROMPT_TOKENS])
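
    # For scale (rough estimate, not from the original code): with the
    # cl100k_base encoding an English token averages about four characters,
    # so the 14000-token cap keeps roughly the first 56k characters of a text.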

    def _call_llm(self, prompt: str, model: str):
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
            )
            return completion.choices[0].message.content
        except BadRequestError as err:  # TODO: consider handling RateLimitError too
            self.warning(str(err))
            # return an empty JSON object so the caller can degrade gracefully
            return "{}"
172