Passed
Push — experiment-llm-rescoring ( f86252...dcbcd9 )
by Juho
02:51
created

annif.backend.llm.LLMBackend._train()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 4
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""TODO"""
2
3
from __future__ import annotations
4
5
import json
6
import os
7
from typing import TYPE_CHECKING, Any
8
9
from openai import AzureOpenAI, BadRequestError
10
11
import annif.eval
12
import annif.parallel
13
import annif.util
14
from annif.exception import NotSupportedException
15
from annif.suggestion import SubjectSuggestion, SuggestionBatch
16
17
from . import backend
18
19
# from openai import AsyncAzureOpenAI
20
21
22
if TYPE_CHECKING:
23
    from datetime import datetime
24
25
    from annif.corpus.document import DocumentCorpus
26
27
28
class BaseLLMBackend(backend.AnnifBackend):
    """Base class for backends that post-process source project results with an LLM.

    Provides helpers for querying the source projects listed in the
    ``sources`` backend parameter, and creates an Azure OpenAI client in
    :meth:`initialize`.
    """

    # Azure OpenAI API version used unless the ``api_version`` backend
    # parameter overrides it.
    DEFAULT_API_VERSION = "2024-02-15-preview"

    def _get_sources_attribute(self, attr: str) -> list[Any]:
        """Return the attribute ``attr`` of every source project, in source order.

        :param attr: name of the project attribute to read, e.g. ``"is_trained"``
            or ``"modification_time"``
        :returns: one value per source project; the element type depends on
            ``attr`` (bool, datetime, None, ...)
        """
        params = self._get_backend_params(None)
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in annif.util.parse_sources(params["sources"])
        ]

    def initialize(self, parallel: bool = False) -> None:
        """Initialize all source projects and create the Azure OpenAI client.

        The endpoint is read from the ``endpoint`` backend parameter, the API
        key from the ``AZURE_OPENAI_KEY`` environment variable and the API
        version from the optional ``api_version`` backend parameter.

        :param parallel: passed on to each source project's ``initialize``
        """
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            self.project.registry.get_project(project_id).initialize(parallel)

        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version=params.get("api_version", self.DEFAULT_API_VERSION),
        )

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        """Run ``suggest`` on each given source project for the same texts.

        :param texts: documents to get suggestions for
        :param sources: ``(project_id, weight)`` pairs; the weight is ignored here
        :returns: mapping from project ID to that project's suggestion batch
        """
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }


class LLMBackend(BaseLLMBackend):
    """Backend that re-scores the suggestions of a source project with an LLM.

    The suggestions of the first configured source project are shown to a
    chat-completion model together with (a prefix of) the document text. The
    model answers with a JSON object mapping each label to a score, and the
    final score of each subject is a weighted mean of the base score and the
    LLM score. The weight is the probability the model assigned to the answer
    line that carries the label, computed from the returned token logprobs.
    """

    name = "llm"

    # Maximum number of characters of document text included in a prompt.
    chars_max = 40000

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
        There must be the same number of items in the JSON as there are in the
        input keyword list.
    """

    @property
    def is_trained(self) -> bool:
        """True when every source project reports itself as trained."""
        return all(self._get_sources_attribute("is_trained"))

    @property
    def modification_time(self) -> datetime | None:
        """Latest known modification time among the source projects, or None."""
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(
        self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0
    ) -> None:
        """Always fail: this backend has no trainable model of its own.

        :raises NotSupportedException: unconditionally
        """
        raise NotSupportedException("Training LLM backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Re-score the first source project's suggestions for each text.

        :param texts: documents to get suggestions for
        :param params: backend parameters; uses ``sources`` and ``model``
        :returns: one list of re-scored suggestions per text
        """
        sources = annif.util.parse_sources(params["sources"])
        model = params["model"]

        # Only the first source project's suggestions are re-scored, so only
        # that project needs to be queried.
        base_project_id = sources[0][0]
        base_suggestion_batch = self._suggest_with_sources(texts, sources[:1])[
            base_project_id
        ]

        batch_results = []
        for text, base_suggestions in zip(texts, base_suggestion_batch):
            # NOTE(review): labels are always taken in English ("en").
            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            prompt = (
                "Here is the text:\n"
                + text[: self.chars_max]
                + "\n"
                + "And here are the keywords:\n"
                + "\n".join(base_labels)
            )
            answer, weights = self._call_llm(prompt, model)
            self.debug(f"LLM answer: {answer}")
            self.debug(f"LLM label weights: {weights}")
            llm_result = self._parse_answer(answer)
            batch_results.append(
                self._get_llm_suggestions(
                    llm_result, base_labels, base_suggestions, weights
                )
            )
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _parse_answer(self, answer: str | None) -> dict[str, Any]:
        """Decode the LLM's JSON answer; return an empty dict if it is unusable.

        Handles a missing answer (None), malformed JSON and JSON that is not
        an object, so a bad answer degrades to "no LLM scores" instead of
        aborting the whole batch.
        """
        try:
            llm_result = json.loads(answer)
        except (TypeError, ValueError) as err:  # JSONDecodeError is a ValueError
            self.warning(f"Failed to parse LLM answer as JSON: {err}")
            return {}
        if not isinstance(llm_result, dict):
            self.warning("LLM answer is not a JSON object; ignoring it")
            return {}
        return llm_result

    def _get_llm_suggestions(self, llm_result, base_labels, base_suggestions, weights):
        """Combine base scores with LLM scores into new subject suggestions.

        For each base suggestion the new score is
        ``(base + weight * llm) / (1 + weight)``. A label missing from the LLM
        answer or from ``weights``, or with a non-numeric score, gets weight
        0.0, i.e. it keeps its base score.

        :param llm_result: label -> score mapping parsed from the LLM answer
        :param base_labels: labels of ``base_suggestions``, in the same order
        :param base_suggestions: the source project's suggestions for one text
        :param weights: label -> probability of the answer line for that label
        :returns: list of ``SubjectSuggestion`` in base suggestion order
        """
        suggestions = []
        for blabel, bsuggestion in zip(base_labels, base_suggestions):
            try:
                score = float(llm_result[blabel])
                weight = weights[blabel]
            except (KeyError, TypeError, ValueError):
                self.debug(f"Base label {blabel} not found in LLM labels")
                score = 0.0
                weight = 0.0
            # The prompt asks for scores in 0.0-1.0; clamp in case the model strays.
            score = min(max(score, 0.0), 1.0)
            mean_score = (bsuggestion.score + weight * score) / (1 + weight)
            suggestions.append(
                SubjectSuggestion(subject_id=bsuggestion.subject_id, score=mean_score)
            )
        return suggestions

    def _call_llm(
        self, prompt: str, model: str
    ) -> tuple[str | None, dict[str, float]]:
        """Send ``prompt`` to the chat model and return its answer and label weights.

        :param prompt: user message containing the text and keyword list
        :param model: deployment/model name passed to the completions API
        :returns: ``(answer, weights)`` where ``answer`` is the raw message
            content (may be None) and ``weights`` maps each label found in the
            answer to its line probability. On ``BadRequestError`` returns
            ``("{}", {})`` so callers can always unpack two values.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
                logprobs=True,
            )
        except BadRequestError as err:
            self.warning(f"LLM request failed: {err}")
            return "{}", {}

        choice = completion.choices[0]
        token_logprobs = (
            choice.logprobs.content if choice.logprobs is not None else None
        )
        lines = self._get_logprobs(token_logprobs or [])
        return choice.message.content, self._get_probs(lines)

    def _get_logprobs(self, content):
        """Split the answer's tokens into lines paired with their probabilities.

        A line's probability is ``exp`` of the sum of its tokens' logprobs. A
        line ends at the first token containing a newline; any trailing text
        without a newline is kept as a final line, so no label is lost.

        :param content: sequence of token objects with ``token`` and ``logprob``
        :returns: list of ``(line_text, probability)`` tuples
        """
        import math

        lines = []
        line = ""
        line_logprob = 0.0
        for token in content:
            line += token.token
            line_logprob += token.logprob
            if "\n" in token.token:
                lines.append((line, math.exp(line_logprob)))
                line = ""
                line_logprob = 0.0
        if line:
            lines.append((line, math.exp(line_logprob)))
        return lines

    @staticmethod
    def _parse_label(line: str) -> str | None:
        """Return the first JSON string on ``line`` (the object key), or None.

        Decoding with the JSON decoder resolves escapes such as ``\\"`` or
        ``\\u00e4`` so the label matches the keys produced by ``json.loads``.
        """
        start = line.find('"')
        if start == -1:
            return None
        try:
            label, _ = json.JSONDecoder().raw_decode(line, start)
        except ValueError:
            return None
        return label if isinstance(label, str) else None

    def _get_probs(self, lines):
        """Map each label found in the answer lines to that line's probability.

        Lines without a label (e.g. ``{`` or ``}``) are skipped.

        :param lines: ``(line_text, probability)`` tuples from ``_get_logprobs``
        :returns: dict of label -> probability
        """
        probs = {}
        for line, prob in lines:
            label = self._parse_label(line)
            if label is None:
                self.debug(f"No label found on LLM answer line: {line!r}")
                continue
            probs[label] = prob
        return probs