Passed
Push — experiment-llm-rescoring-with-... (524821) by Juho
created 04:29

annif.backend.llm.LLMBackend._get_logprobs()   A

Complexity
  Conditions: 3

Size
  Total Lines: 42
  Code Lines: 17

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric   Value
cc       3
eloc     17
nop      2
dl       0
loc      42
rs       9.55
c        0
b        0
f        0
"""Backend that uses a large language model (LLM) to rescore subject
suggestions given by other Annif projects"""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

import tiktoken
from openai import AzureOpenAI, BadRequestError

import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend

if TYPE_CHECKING:
    from datetime import datetime

    from annif.corpus.document import DocumentCorpus


class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends"""

    def _get_sources_attribute(self, attr: str) -> list[bool | None]:
        """Collect the given attribute from all source projects"""
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel: bool = False) -> None:
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

        # create the Azure OpenAI client; the API key is read from the environment
        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        """Get suggestions for the texts from each source project"""
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }


class LLMBackend(BaseLLMBackend):
    """LLM backend that rescores subject suggestions from source projects
    using a large language model"""

    name = "llm"

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and a completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
        There must be the same number of objects in the JSON as there are lines in the
        input keyword list; do not skip scoring any keywords.
    """
    # Give zero or very low score to the keywords that do not describe the text.

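    # Illustrative example of the contract defined by the system prompt
    # (hypothetical keywords and scores): given the keywords "cats" and
    # "quantum physics" for a text about pets, the model should reply with
    # JSON such as {"cats": 0.9, "quantum physics": 0.0}.
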
    @property
    def is_trained(self) -> bool:
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self) -> datetime | None:
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training LLM backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        model = params["model"]
        llm_scores_weight = float(params["llm_scores_weight"])
        # strip the trailing version suffix, e.g. "gpt-4-0613" -> "gpt-4"
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])

        batch_results = []
        base_suggestion_batch = self._suggest_with_sources(texts, sources)[
            sources[0][0]
        ]

        for text, base_suggestions in zip(texts, base_suggestion_batch):
            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            prompt = "Here are the keywords:\n" + "\n".join(base_labels) + "\n" * 3

            text = self._truncate_text(text, encoding)
            prompt += "Here is the text:\n" + text + "\n"

            answer = self._call_llm(prompt, model)
            self.debug(f"LLM answer: {answer}")
            try:
                llm_result = json.loads(answer)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                self.warning(f"Could not parse LLM answer as JSON: {err}")
                llm_result = {}
            results = self._get_llm_suggestions(
                llm_result,
                base_labels,
                base_suggestions,
                llm_scores_weight,
            )
            batch_results.append(results)
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _truncate_text(self, text: str, encoding) -> str:
        """Truncate the text so that it contains at most MAX_PROMPT_TOKENS
        according to the OpenAI tokenizer"""

        MAX_PROMPT_TOKENS = 14000
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:MAX_PROMPT_TOKENS])

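    # Illustrative note (hypothetical numbers): with the cl100k_base encoding
    # used by GPT-4 models, a 20000-token document is cut back to its first
    # 14000 tokens, while shorter texts pass through unchanged.
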
    def _get_llm_suggestions(
        self,
        llm_result: dict[str, float],
        base_labels: list[str],
        base_suggestions,
        llm_scores_weight: float,
    ) -> list[SubjectSuggestion]:
        suggestions = []
        for blabel, bsuggestion in zip(base_labels, base_suggestions):
            try:
                score = llm_result[blabel]
            except KeyError:
                self.warning(f"Base label {blabel} not found in LLM labels")
                score = bsuggestion.score  # fall back to the base suggestion score
            subj_id = bsuggestion.subject_id

            # weighted mean of the LLM and base scores
            base_scores_weight = 1.0 - llm_scores_weight
            mean_score = (
                base_scores_weight * bsuggestion.score + llm_scores_weight * score
            ) / (base_scores_weight + llm_scores_weight)
            suggestions.append(SubjectSuggestion(subject_id=subj_id, score=mean_score))
        return suggestions

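    # Worked example for the weighted mean above (made-up numbers): with
    # llm_scores_weight=0.5, a base score of 0.6 and an LLM score of 0.8
    # combine to (0.5 * 0.6 + 0.5 * 0.8) / (0.5 + 0.5) = 0.7.
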
    def _call_llm(self, prompt: str, model: str) -> str | None:
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
            )
            return completion.choices[0].message.content
        except BadRequestError as err:  # openai.RateLimitError could also be caught
            self.warning(str(err))
            return "{}"
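
For reference, the weighted-mean rescoring performed by LLMBackend can be exercised in isolation. Below is a minimal, illustrative Python sketch; the labels, scores and weight value are invented, and only the merging formula and the missing-label fallback mirror the code above.

    import json

    # hypothetical inputs: scores from the source project and the LLM's JSON reply
    base = {"cats": 0.5, "dogs": 0.4}
    answer = '{"cats": 0.9, "dogs": 0.1}'
    llm_scores_weight = 0.7
    base_scores_weight = 1.0 - llm_scores_weight

    llm_result = json.loads(answer)
    merged = {
        # weighted mean; a label missing from the LLM reply keeps its base score
        label: (
            base_scores_weight * bscore
            + llm_scores_weight * llm_result.get(label, bscore)
        )
        / (base_scores_weight + llm_scores_weight)
        for label, bscore in base.items()
    }
    print(merged)  # roughly {'cats': 0.78, 'dogs': 0.19}, up to float rounding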