Passed
Push — experiment-llm-keyword-extract... ( cfc7a0 )
by Juho
05:00
created

annif.backend.llm.LLMBackend._train()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 4
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
"""TODO"""
2
3
from __future__ import annotations
4
5
import json
6
import os
7
import re
8
from collections import defaultdict
9
from typing import TYPE_CHECKING, Any
10
11
import joblib
12
import tiktoken
13
from openai import AzureOpenAI, BadRequestError
14
from rdflib.namespace import SKOS
15
16
import annif.eval
17
import annif.parallel
18
import annif.util
19
from annif.exception import ConfigurationException, NotSupportedException
20
from annif.suggestion import SubjectSuggestion, SuggestionBatch
21
22
from . import backend
23
24
# from openai import AsyncAzureOpenAI
25
26
27
if TYPE_CHECKING:
28
    from datetime import datetime
29
30
    from rdflib.term import URIRef
31
32
    from annif.corpus.document import DocumentCorpus
33
34
35
class BaseLLMBackend(backend.AnnifBackend):
    """Base class for backends that query a large language model through the
    Azure OpenAI chat completions API.

    Subclasses must implement ``_initialize_index``, which ``initialize`` calls
    after the API client has been created.
    """

    # API version used when the backend configuration does not set "api_version"
    DEFAULT_API_VERSION = "2024-02-15-preview"

    def initialize(self, parallel: bool = False) -> None:
        """Create the Azure OpenAI client and initialize the subclass index.

        The endpoint comes from the ``endpoint`` backend parameter and the API
        key from the ``AZURE_OPENAI_API_KEY`` environment variable. An optional
        ``api_version`` backend parameter overrides DEFAULT_API_VERSION.

        :param parallel: not used by this implementation.
        :raises ConfigurationException: if no ``endpoint`` is configured.
        """
        params = self._get_backend_params(None)

        endpoint = params.get("endpoint")
        if not endpoint:
            # Fail with a configuration error instead of a bare KeyError
            raise ConfigurationException(
                "endpoint must be set in project configuration",
                backend_id=self.backend_id,
            )

        self.client = AzureOpenAI(
            azure_endpoint=endpoint,
            api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            api_version=params.get("api_version", self.DEFAULT_API_VERSION),
        )
        self._initialize_index()
49
50
51
class LLMBackend(BaseLLMBackend, backend.AnnifBackend):
    """Backend that asks a large language model for scored keywords describing
    a text and maps those keywords to vocabulary concepts via a label index.

    The index maps normalized concept labels to sets of concept URIs. It is
    built from the project vocabulary on first use and saved in the backend
    data directory.
    """

    name = "llm"
    # defaults for uninitialized instances
    _index = None
    INDEX_FILE = "llm-index"

    # Maximum number of document tokens included in the prompt
    MAX_PROMPT_TOKENS = 14000
    # tiktoken encoding used when the model name is not recognized by tiktoken;
    # it only affects how the input text is truncated
    FALLBACK_ENCODING = "cl100k_base"

    DEFAULT_PARAMETERS = {
        "label_types": ["prefLabel", "altLabel"],
        "remove_parentheses": False,
    }

    system_prompt = """
        You are a professional subject indexer.
        You will be given a text. Your task is to give a list of keywords to describe
        the text along scores for the keywords with a value between 0.0 and 1.0. The
        score value should depend on how well the keyword represents the text: a perfect
        keyword should have score 1.0 and completely unrelated keyword score
        0.0. You must output JSON with keywords as field names and add their scores
        as field values.
    """
    # Candidate extra prompt instruction, currently not used:
    # Give zero or very low score to the keywords that do not describe the text.

    @property
    def is_trained(self) -> bool:
        # Nothing is trained locally (see _train), so the backend is always usable
        return True

    @property
    def modification_time(self) -> datetime | None:
        # There is no locally trained model, hence no modification time
        return None

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        raise NotSupportedException("Training LLM backend is not possible.")

    @property
    def label_types(self) -> list[URIRef]:
        """SKOS label properties used when building the label index.

        A comma-separated string from the configuration is validated; otherwise
        the default list from DEFAULT_PARAMETERS is used.
        """
        if isinstance(self.params["label_types"], str):  # Label types set by user
            label_types = [lt.strip() for lt in self.params["label_types"].split(",")]
            self._validate_label_types(label_types)
        else:
            label_types = self.params["label_types"]  # The defaults
        return [getattr(SKOS, lt) for lt in label_types]

    def _validate_label_types(self, label_types: list[str]) -> None:
        """Raise ConfigurationException for any label type other than
        prefLabel, altLabel or hiddenLabel."""
        for lt in label_types:
            if lt not in ("prefLabel", "altLabel", "hiddenLabel"):
                raise ConfigurationException(
                    f"invalid label type {lt}", backend_id=self.backend_id
                )

    def _initialize_index(self) -> None:
        """Load the label index from the data directory, or create and save it
        if no saved index exists. Does nothing if the index is already set."""
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            if os.path.exists(path):
                # NOTE: an index file saved by an older version of this backend
                # may contain unnormalized labels; delete it to rebuild.
                self._index = joblib.load(path)
                self.debug(f"Loaded index from {path} with {len(self._index)} labels")
            else:
                self.info("Creating index")
                self._index = self._create_index()
                self._save_index(path)
                self.info(f"Created index with {len(self._index)} labels")

    def _save_index(self, path: str) -> None:
        """Atomically save the index as INDEX_FILE in the data directory.

        The ``path`` argument is not used; the target is datadir/INDEX_FILE.
        """
        annif.util.atomic_save(
            self._index, self.datadir, self.INDEX_FILE, method=joblib.dump
        )

    def _create_index(self) -> dict[str, set[str]]:
        """Build a mapping from normalized label to the URIs of the concepts
        carrying that label in the configured language."""
        index = defaultdict(set)
        skos_vocab = self.project.vocab.skos
        for concept in skos_vocab.concepts:
            uri = str(concept)
            labels_by_lang = skos_vocab.get_concept_labels(concept, self.label_types)
            for label in labels_by_lang[self.params["language"]]:
                # Keys must be normalized exactly like the keyphrases looked up
                # in _keyphrase2uris, otherwise multi-word labels never match
                index[self._normalize_label(label)].add(uri)
        index.pop("", None)  # Remove possible empty string entry
        return dict(index)

    def _suggest(
        self, text: str, params: dict[str, Any]
    ) -> list[SubjectSuggestion]:
        """Suggest subjects for a text by asking the LLM for scored keywords.

        The text is truncated to MAX_PROMPT_TOKENS tokens, the LLM answer is
        parsed into keyphrases, and keyphrases are matched against the label
        index. Returns at most ``limit`` suggestions with a positive score,
        highest score first.
        """
        model = params["model"]
        limit = int(params["limit"])

        encoding = self._get_encoding(model)
        text = self._truncate_text(text, encoding)
        prompt = "Here is the text:\n" + text + "\n"

        answer = self._call_llm(prompt, model)
        keyphrases = self._parse_keyphrases(answer)
        suggestions = self._keyphrases2suggestions(keyphrases)

        # Rank before applying the limit so the best-scoring matches are kept;
        # sorted() is stable, so ties keep the LLM's order
        ranked = sorted(
            (item for item in suggestions if item[1] > 0.0),
            key=lambda item: item[1],
            reverse=True,
        )
        subject_suggestions = []
        for uri, score in ranked:
            subject_id = self.project.subjects.by_uri(uri)
            # NOTE(review): by_uri is assumed to return None for unknown URIs;
            # such URIs cannot become suggestions, so they are skipped
            if subject_id is not None:
                subject_suggestions.append(
                    SubjectSuggestion(subject_id=subject_id, score=score)
                )
        return subject_suggestions[:limit]

    def _get_encoding(self, model: str) -> tiktoken.Encoding:
        """Return the tiktoken encoding for a model (deployment) name.

        The last dash-separated component of the name is dropped before the
        lookup. Names that tiktoken does not recognize fall back to
        FALLBACK_ENCODING.
        """
        try:
            return tiktoken.encoding_for_model(model.rsplit("-", 1)[0])
        except KeyError:
            self.debug(
                f"No tiktoken encoding known for model {model}, "
                f"using {self.FALLBACK_ENCODING}"
            )
            return tiktoken.get_encoding(self.FALLBACK_ENCODING)

    def _truncate_text(self, text, encoding):
        """Truncate text so it contains at most MAX_PROMPT_TOKENS tokens
        according to the given tiktoken encoding."""
        tokens = encoding.encode(text)
        return encoding.decode(tokens[: self.MAX_PROMPT_TOKENS])

    def _call_llm(self, prompt: str, model: str) -> str | None:
        """Send the prompt to the chat completions API and return the answer.

        The request uses temperature 0, a fixed seed and asks for a JSON object
        response. If the API rejects the request (BadRequestError), a warning is
        logged and "{}" is returned; other API errors (e.g. rate limiting)
        propagate to the caller.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": prompt},
        ]
        try:
            completion = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.0,
                seed=0,
                max_tokens=1800,
                top_p=0.95,
                frequency_penalty=0,
                presence_penalty=0,
                stop=None,
                response_format={"type": "json_object"},
            )
        except BadRequestError as err:
            self.warning(f"LLM request failed: {err}")
            return "{}"
        return completion.choices[0].message.content

    def _parse_keyphrases(self, answer: str | None) -> list[tuple[str, float]]:
        """Parse the LLM answer, a JSON object of keyword -> score, into
        (keyphrase, score) pairs.

        Scores are converted to float and clamped to [0.0, 1.0] so that
        _combine_scores stays well-defined; entries with a non-numeric score
        are skipped. An answer that is not a JSON object yields an empty list.
        """
        try:
            llm_result = json.loads(answer)
        except (TypeError, json.decoder.JSONDecodeError) as err:
            self.warning(f"Could not parse LLM answer as JSON: {err}")
            return []
        if not isinstance(llm_result, dict):
            self.warning("LLM answer is not a JSON object")
            return []

        keyphrases = []
        for keyphrase, raw_score in llm_result.items():
            try:
                score = float(raw_score)
            except (TypeError, ValueError):
                self.debug(f"Ignoring keyphrase {keyphrase!r}: score {raw_score!r}")
                continue
            # "score > 0.0" is False for NaN too, so NaN is mapped to 0.0
            score = min(score, 1.0) if score > 0.0 else 0.0
            keyphrases.append((keyphrase, score))
        return keyphrases

    def _keyphrases2suggestions(
        self, keyphrases: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        """Map (keyphrase, score) pairs to (uri, score) pairs.

        A keyphrase matching several concepts gives its score to each of them;
        scores of a concept matched by several keyphrases are merged with
        _combine_scores. Unmatched keyphrases are logged at debug level.
        """
        suggestions = []
        not_matched = []
        for kp, score in keyphrases:
            uris = self._keyphrase2uris(kp)
            if uris:
                # Sort the URI set so the output order is reproducible
                suggestions.extend((uri, score) for uri in sorted(uris))
            else:
                not_matched.append((kp, score))
        # Remove duplicate uris, conflating the scores
        suggestions = self._combine_suggestions(suggestions)
        self.debug(
            "Keyphrases not matched:\n"
            + "\t".join(
                [
                    kp[0] + " " + str(kp[1])
                    for kp in sorted(not_matched, reverse=True, key=lambda kp: kp[1])
                ]
            )
        )
        return suggestions

    def _keyphrase2uris(self, keyphrase: str) -> set[str]:
        """Return the URIs whose index label matches the keyphrase, or an empty
        set. Uses the same normalization as _create_index."""
        return self._index.get(self._normalize_label(keyphrase), set())

    def _normalize_label(self, label: str) -> str:
        """Normalize a label or keyphrase into an index key: optionally drop a
        parenthesized part, tokenize/normalize the words and sort them."""
        label = str(label)
        if annif.util.boolean(self.params["remove_parentheses"]):
            label = re.sub(r" \(.*\)", "", label)
        normalized_label = self._normalize_phrase(label)
        return self._sort_phrase(normalized_label)

    def _normalize_phrase(self, phrase: str) -> str:
        """Join the project analyzer's word tokens (no token filtering)."""
        return " ".join(self.project.analyzer.tokenize_words(phrase, filter=False))

    def _sort_phrase(self, phrase: str) -> str:
        """Sort the whitespace-separated words of a phrase alphabetically."""
        words = phrase.split()
        return " ".join(sorted(words))

    def _combine_suggestions(
        self, suggestions: list[tuple[str, float]]
    ) -> list[tuple[str, float]]:
        """Merge duplicate URIs, combining their scores with _combine_scores.

        The order of first occurrence is preserved.
        """
        combined_suggestions = {}
        for uri, score in suggestions:
            if uri not in combined_suggestions:
                combined_suggestions[uri] = score
            else:
                old_score = combined_suggestions[uri]
                combined_suggestions[uri] = self._combine_scores(score, old_score)
        return list(combined_suggestions.items())

    def _combine_scores(self, score1: float, score2: float) -> float:
        """Combine two scores in [0, 1] into one score in [0, 1].

        The scores are mapped to [0.5, 1], conflated as independent
        probabilities, and mapped back to [0, 1].
        """
        # The result is never smaller than the greater input
        score1 = score1 / 2 + 0.5
        score2 = score2 / 2 + 0.5
        confl = score1 * score2 / (score1 * score2 + (1 - score1) * (1 - score2))
        return (confl - 0.5) * 2
252