Passed: Push on experiment-llm-rescoring-from-... (5179d6) by Juho
Duration 11:40 (queued 06:59)

annif.backend.llm.LLMBackend._parse_line()   A

Complexity:   Conditions 2
Size:         Total Lines 8, Code Lines 8
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Metric  Value
cc      2
eloc    8
nop     2
dl      0
loc     8
rs      10
c       0
b       0
f       0
"""Backend that uses a large language model to rescore subject suggestions
produced by other Annif projects"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any

import numpy as np
import tiktoken
from openai import AzureOpenAI, BadRequestError

import annif.eval
import annif.parallel
import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend

# from openai import AsyncAzureOpenAI


if TYPE_CHECKING:
    from datetime import datetime

    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends that rescore suggestions from source projects"""

    def _get_sources_attribute(self, attr: str) -> list[bool | None]:
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel: bool = False) -> None:
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

        # self.client = AsyncAzureOpenAI(
        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }


class LLMBackend(BaseLLMBackend):
    """LLM backend that rescores suggestions from source projects with a large
    language model"""

    name = "llm"

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        decide whether a keyword is suitable for the text and describes it well:
        give output as a binary value; 1 for good keywords and 0 for keywords that do
        not describe the text. You must output JSON with keywords as field names and
        the binary scores as field values.
        There must be the same number of items in the JSON as there are in the
        input keyword list, so give either 0 or 1 to every input keyword.
    """

    @property
    def is_trained(self) -> bool:
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self) -> datetime | None:
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training LLM backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        model = params["model"]
        llm_scores_weight = float(params["llm_scores_weight"])
        # llm_probs_weight = float(params["llm_probs_weight"])
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])

        batch_results = []
        # rescore the suggestions of the first source project only
        base_suggestion_batch = self._suggest_with_sources(texts, sources)[
            sources[0][0]
        ]

        for text, base_suggestions in zip(texts, base_suggestion_batch):
            text = self._truncate_text(text, encoding)
            prompt = "Here is the text:\n" + text + "\n"

            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            prompt += "And here are the keywords:\n" + "\n".join(base_labels)
            llm_result = self._call_llm(prompt, model)
            print(llm_result)
            # try:
            #     llm_result = json.loads(llm_labels)
            # except (TypeError, json.decoder.JSONDecodeError) as err:
            #     print(err)
            #     llm_result = dict()
            results = self._map_llm_suggestions(
                llm_result,
                base_labels,
                base_suggestions,
                llm_scores_weight,
            )
            batch_results.append(results)
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _truncate_text(self, text: str, encoding) -> str:
        """Truncate text so it contains at most MAX_PROMPT_TOKENS according to the
        OpenAI tokenizer"""

        MAX_PROMPT_TOKENS = 14000
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:MAX_PROMPT_TOKENS])

    def _map_llm_suggestions(
        self,
        llm_result,
        base_labels,
        base_suggestions,
        llm_scores_weight,
    ):
        """Combine the LLM scores with the base suggestion scores as a weighted
        mean and return the resulting SubjectSuggestions"""

        suggestions = []
        for blabel, bsuggestion in zip(base_labels, base_suggestions):
            try:
                score = llm_result[blabel]
            except KeyError:
                print(f"Base label {blabel} not found in LLM labels")
                score = bsuggestion.score  # use only base suggestion score
            subj_id = bsuggestion.subject_id

            base_scores_weight = 1.0 - llm_scores_weight
            mean_score = (
                base_scores_weight * bsuggestion.score + llm_scores_weight * score
            ) / (
                base_scores_weight + llm_scores_weight
            )  # weighted mean of LLM and base scores!
            suggestions.append(SubjectSuggestion(subject_id=subj_id, score=mean_score))
        return suggestions

    def _call_llm(self, prompt: str, model: str):
        """Send the prompt to the chat completions API and parse the response
        token logprobs into a label-to-score mapping"""

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
                logprobs=True,
                # top_logprobs=2,
            )
            logprobs_completion = completion.choices[0].logprobs.content
            return self._get_results(logprobs_completion)
        except BadRequestError as err:  # openai.RateLimitError
            print(err)
            return dict()

    def _get_results(self, logprobs_completion):
        """Parse the completion tokens line by line into a dict mapping each
        keyword label to a score derived from the 0/1 token probability"""

        # labels, probs = [], []
        results = dict()
        line = ""
        prev_token = None
        for token in logprobs_completion:
            # print("Token:", token.token)
            # print("Linear prob:", np.round(np.exp(token.logprob) * 100, 2), "%")
            # prev_linear_prob = np.exp(token.logprob)

            line += token.token
            if "\n" in token.token:
                print("Line is: " + line)
                label, boolean_score = self._parse_line(line)
                if label != "<failed>":
                    # results[label] = prev_linear_prob
                    results[label] = self._get_score(prev_token)
                line = ""
            # remember this token; when the newline token arrives, prev_token
            # is the preceding 0/1 score token
            prev_token = token
        return results

    def _parse_line(self, line):
        """Extract the keyword label and the 0/1 value from one line of the JSON
        output; return ("<failed>", None) if the line cannot be parsed"""

        try:
            label = line.split('"')[1]
            boolean_score = line.split(":")[1].strip().replace(",", "")
        except IndexError:
            print(f"Failed parsing line: '{line}'")
            return "<failed>", None
        return label, boolean_score

    def _get_score(self, token):
        """Turn the 0/1 token and its logprob into a score in [0, 1]: the linear
        probability of "1", or its complement for "0" """

        linear_prob = np.exp(token.logprob)
        if token.token == "1":
            return linear_prob
        elif token.token == "0":
            return 1.0 - linear_prob
        else:
            print(token)
            return None
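
For reference, a minimal, self-contained sketch of the scoring arithmetic used above. The Token class is a hypothetical stand-in for the logprob token objects returned by the chat completions API (only the token and logprob attributes are assumed); llm_score mirrors _get_score and rescore mirrors the weighted mean in _map_llm_suggestions, with llm_scores_weight as the LLM weight and its complement as the base weight.

# Illustration only; not part of the backend itself.
from __future__ import annotations

import numpy as np


class Token:
    def __init__(self, token: str, logprob: float):
        self.token = token      # the emitted "0" or "1"
        self.logprob = logprob  # its log-probability


def llm_score(token: Token) -> float | None:
    """Linear probability that the keyword fits: p for "1", 1 - p for "0"."""
    p = np.exp(token.logprob)
    if token.token == "1":
        return p
    if token.token == "0":
        return 1.0 - p
    return None


def rescore(base_score: float, llm_value: float, llm_scores_weight: float) -> float:
    """Weighted mean of the base suggestion score and the LLM-derived score."""
    base_weight = 1.0 - llm_scores_weight
    return base_weight * base_score + llm_scores_weight * llm_value


# Example: base score 0.4, the model answered "1" with 90 % probability,
# llm_scores_weight 0.7  ->  0.3 * 0.4 + 0.7 * 0.9 ≈ 0.75
print(rescore(0.4, llm_score(Token("1", np.log(0.9))), 0.7))

Because the two weights always sum to one, dividing by their sum in _map_llm_suggestions leaves the weighted mean unchanged; it only matters if the weighting scheme is later generalised.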