Passed
Push — experiment-llm-ensemble-backen... (1e9159...150dab) by Juho
created 02:55

LLMEnsembleBackend._get_labels_batch() — rating A

Complexity:  1 condition
Size:        9 total lines, 7 code lines
Duplication: 0 lines (0 %)
Importance:  0 changes

Metric  Value
cc      1
eloc    7
nop     2
dl      0
loc     9
rs      10
c       0
b       0
f       0

"""Language model based ensemble backend that combines results from multiple
projects."""

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

import tiktoken
from openai import AzureOpenAI, BadRequestError, RateLimitError

import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend, ensemble

if TYPE_CHECKING:
    from annif.corpus.document import DocumentCorpus

class BaseLLMBackend(backend.AnnifBackend):
    """Base class for backends that use an LLM via the Azure OpenAI API"""

    def initialize(self, parallel: bool = False) -> None:
        # initialize all the source projects
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

        # the API key is read from the AZURE_OPENAI_KEY environment variable
        self.client = AzureOpenAI(
            azure_endpoint=params["endpoint"],
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

class LLMEnsembleBackend(BaseLLMBackend, ensemble.BaseEnsembleBackend):
    """Ensemble backend that combines results from multiple projects and uses
    an LLM to rescore the merged suggestions"""

    name = "llm_ensemble"

    system_prompt = """
        You will be given text and a list of keywords to describe it. Your task is to
        score the keywords with a value between 0.0 and 1.0. The score value
        should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and a completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
        There must be the same number of objects in the JSON as there are lines in the
        input keyword list; do not skip scoring any keywords.
    """
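
    # For illustration (example values only, not from the source): given the
    # keywords ["machine learning", "neural networks", "gardening"], the model
    # is expected to reply with JSON such as
    #   {"machine learning": 0.9, "neural networks": 0.85, "gardening": 0.0}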

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training the LLM ensemble backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        sources = annif.util.parse_sources(params["sources"])
        batch_by_source = self._suggest_with_sources(texts, sources)
        merged_source_batch = self._merge_source_batches(
            batch_by_source, sources, params
        )

        # Add LLM suggestions to the source batches
        batch_by_source[self.project.project_id] = self._llm_suggest_batch(
            texts, merged_source_batch, params
        )
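        # Assuming _merge_source_batches computes a weighted average over the
        # sources, llm_weight controls how much the LLM rescoring contributes
        # relative to the combined weights of the source projects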
        new_sources = sources + [(self.project.project_id, float(params["llm_weight"]))]
        return self._merge_source_batches(batch_by_source, new_sources, params)

    def _llm_suggest_batch(
        self,
        texts: list[str],
        suggestion_batch: SuggestionBatch,
        params: dict[str, Any],
    ) -> SuggestionBatch:
        model = params["model"]
        encoding = tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        labels_batch = self._get_labels_batch(suggestion_batch)

        llm_batch_suggestions = []
        for text, labels in zip(texts, labels_batch):
            prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
            text = self._truncate_text(text, encoding)
            prompt += "Here is the text:\n" + text + "\n"

            response = self._call_llm(prompt, model)
            try:
                llm_result = json.loads(response)
            except (TypeError, json.decoder.JSONDecodeError) as err:
                self.warning(f"Failed to parse LLM response as JSON: {err}")
                # use an empty result for this text so that the batch stays
                # aligned with the input texts
                llm_batch_suggestions.append([])
                continue
            llm_suggestions = [
                SubjectSuggestion(subject_id=subject_id, score=score)
                for llm_label, score in llm_result.items()
                # skip labels that are not found in the vocabulary
                if (subject_id := self.project.subjects.by_label(llm_label, "en"))
                is not None
            ]
            llm_batch_suggestions.append(llm_suggestions)
        return SuggestionBatch.from_sequence(
            llm_batch_suggestions,
            self.project.subjects,
        )

    def _get_labels_batch(self, suggestion_batch: SuggestionBatch) -> list[list[str]]:
        # TODO: make language selectable
        return [
            [
                self.project.subjects[suggestion.subject_id].labels["en"]
                for suggestion in suggestion_result
            ]
            for suggestion_result in suggestion_batch
        ]

    def _truncate_text(self, text: str, encoding: tiktoken.Encoding) -> str:
        """Truncate the text so that it contains at most MAX_PROMPT_TOKENS tokens
        according to the OpenAI tokenizer"""

        MAX_PROMPT_TOKENS = 14000
        tokens = encoding.encode(text)
        return encoding.decode(tokens[:MAX_PROMPT_TOKENS])

    def _call_llm(self, prompt: str, model: str) -> str:
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
            )
            return completion.choices[0].message.content
        except (BadRequestError, RateLimitError) as err:
            self.warning(str(err))
            # return an empty JSON object so no suggestions are given for this text
            return "{}"
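
As a quick illustration of the contract between _llm_suggest_batch and _call_llm, here is a minimal standalone sketch. All labels, scores and the sample text are made up; it assumes only the prompt format and the JSON response shape defined by the system prompt above.

import json

# Build a prompt the same way _llm_suggest_batch does (hypothetical inputs)
labels = ["machine learning", "neural networks", "gardening"]
prompt = "Here are the keywords:\n" + "\n".join(labels) + "\n" * 3
prompt += "Here is the text:\n" + "A study of deep neural models for text." + "\n"

# A well-formed model reply scores every keyword, as the system prompt requires
response = '{"machine learning": 0.9, "neural networks": 0.95, "gardening": 0.0}'
scores = json.loads(response)
assert set(scores) == set(labels)  # no keyword may be skipped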