Passed
Push — experiment-llm-rescoring ( 90ef82...2fd311 )
by Juho
created 03:14

annif.backend.llm.LLMBackend._truncate_text()   A

Complexity

Conditions 1

Size

Total Lines 5
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
"""Backend that uses a large language model (LLM) to rescore subject
suggestions given by other Annif projects"""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

from openai import AzureOpenAI, BadRequestError
import tiktoken

import annif.eval
import annif.parallel
import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend

# from openai import AsyncAzureOpenAI


if TYPE_CHECKING:
    from datetime import datetime

    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    def _get_sources_attribute(self, attr: str) -> list[bool | None]:
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel: bool = False) -> None:
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

        # the API key is read from the AZURE_OPENAI_KEY environment variable
        # self.client = AsyncAzureOpenAI(
        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }


class LLMBackend(BaseLLMBackend):
    """LLM backend that rescores suggestions given by source projects"""

    name = "llm"

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and a completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
        There must be the same number of items in the JSON as there are in the
        input keyword list.
    """

    MAX_PROMPT_TOKENS = 15000  # Typically full answer is ~500 tokens

    @property
    def is_trained(self) -> bool:
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self) -> datetime | None:
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training LLM backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        model = params["model"]
        llm_scores_weight = float(params["llm_scores_weight"])
        # llm_probs_weight = float(params["llm_probs_weight"])
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])

        batch_results = []
        base_suggestion_batch = self._suggest_with_sources(texts, sources)[
            sources[0][0]
        ]

        for text, base_suggestions in zip(texts, base_suggestion_batch):
            text = self._truncate_text(text, encoding)
            prompt = "Here is the text:\n" + text + "\n"

            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            prompt += "And here are the keywords:\n" + "\n".join(base_labels)
            answer, probabilities = self._call_llm(prompt, model)
            print(answer)
            print(probabilities)
            try:
                llm_result = json.loads(answer)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                print(err)
                llm_result = dict()
            results = self._get_llm_suggestions(
                llm_result,
                base_labels,
                base_suggestions,
                llm_scores_weight,
                # probabilities,
                # llm_probs_weight,
            )
            batch_results.append(results)
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _truncate_text(self, text: str, encoding) -> str:
        """truncate text so it contains at most MAX_PROMPT_TOKENS according to the
        OpenAI tokenizer"""
        tokens = encoding.encode(text)
        return encoding.decode(tokens[: self.MAX_PROMPT_TOKENS])
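    # Illustration with assumed numbers: a document that encodes to 20000
    # tokens is cut back to its first 15000 tokens (MAX_PROMPT_TOKENS) and
    # decoded to text again, so the prompt stays within the model's context
    # window while leaving room for the ~500-token answer.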

    def _get_llm_suggestions(
        self,
        llm_result,
        base_labels,
        base_suggestions,
        llm_scores_weight,
        # probabilities,
        # llm_probs_weight,
    ):
        suggestions = []
        # print(f"LLM result: {llm_result}")
        for blabel, bsuggestion in zip(base_labels, base_suggestions):
            # score = llm_result.get(blabel, 0)
            try:
                score = llm_result[blabel]
                # probability = probabilities[blabel]
            except KeyError:
                print(f"Base label {blabel} not found in LLM labels")
                score = bsuggestion.score  # use only base suggestion score
                # probability = 0.0
            subj_id = bsuggestion.subject_id

            base_scores_weight = 1.0 - llm_scores_weight
            mean_score = (
                base_scores_weight * bsuggestion.score
                + llm_scores_weight * score  # * probability * llm_probs_weight
            ) / (
                base_scores_weight + llm_scores_weight  # * probability * llm_probs_weight
            )  # weighted mean of LLM and base scores!
            suggestions.append(SubjectSuggestion(subject_id=subj_id, score=mean_score))
        return suggestions
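    # Worked example with made-up numbers: with llm_scores_weight = 0.7 the
    # base weight is 0.3, so a base score of 0.6 and an LLM score of 1.0 give
    # (0.3 * 0.6 + 0.7 * 1.0) / (0.3 + 0.7) = 0.88; a label missing from the
    # LLM answer simply keeps its base score.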

    # async def _call_llm(self, prompt: str, model: str):
    def _call_llm(self, prompt: str, model: str):
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            # completion = await client.chat.completions.create(
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
                # logprobs=True,
            )
            # return completion.choices[0].message.content

            answer = completion.choices[0].message.content
            # lines = self._get_logprobs(completion.choices[0].logprobs.content)
            # probs = self._get_probs(lines)
            # return answer, probs
            return answer, dict()
        except BadRequestError as err:  # openai.RateLimitError
            print(err)
            return "{}", dict()

    def _get_logprobs(self, content):
        import numpy as np

        lines = []
        joint_logprob = 0.0
        line = ""
        line_joint_logprob = 0.0
        for token in content:
            # print("Token:", token.token)
            # print("Log prob:", token.logprob)
            # print("Linear prob:", np.round(np.exp(token.logprob) * 100, 2), "%")
            # print("Bytes:", token.bytes, "\n")
            # aggregated_bytes += token.bytes
            joint_logprob += token.logprob

            line += token.token
            line_joint_logprob += token.logprob
            if "\n" in token.token:
                # print("Line is: "+ line)
                line_prob = np.exp(line_joint_logprob)
                # print("Line's linear prob:",  np.round(line_prob * 100, 2), "%")

                lines.append((line, line_prob))
                line = ""
                line_joint_logprob = 0.0
        #         print()
        # print()
        # print("Joint log prob:", joint_logprob)
        # print("Joint prob:", np.round(np.exp(joint_logprob) * 100, 2), "%")
        return lines
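    # Example with assumed token logprobs: a line consisting of two tokens with
    # logprobs -0.1 and -0.2 gets line_joint_logprob = -0.3, i.e. a linear
    # probability of exp(-0.3) ~ 0.74 for that whole line of the JSON answer.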

    # def _get_probs(self, lines):
    #     probs = dict()
    #     for line, prob in lines:
    #         try:
    #             label = line.split('"')[1]
    #         except IndexError:
    #             print("Failed parsing line: " + line)
    #             continue  # Not a line with label
    #         # probs[label] = 1.0
    #         probs[label] = prob
    #     return probs