Passed
Push — experiment-llm-rescoring ( f86252 )
by Juho
04:01
created

annif.backend.llm.LLMBackend._suggest_batch()   A

Complexity

Conditions 2

Size

Total Lines 28
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 20
nop 3
dl 0
loc 28
rs 9.4
c 0
b 0
f 0
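1
"""Backend that asks a large language model (LLM) to rescore the subject
suggestions produced by another project."""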
2
3
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Any

from openai import AzureOpenAI

import annif.eval
import annif.parallel
import annif.util
from annif.exception import NotSupportedException
from annif.suggestion import SubjectSuggestion, SuggestionBatch

from . import backend
17
18
if TYPE_CHECKING:
19
    from datetime import datetime
20
21
    from annif.corpus.document import DocumentCorpus
22
23
24
class BaseLLMBackend(backend.AnnifBackend):
    """Base class for LLM backends that build on other (source) projects.

    The source projects are read from the "sources" backend parameter, which
    is parsed with annif.util.parse_sources() into (project_id, value) pairs.
    Only the project ids are used by this class.
    """

    def _get_sources_attribute(self, attr: str) -> list[Any]:
        """Return the value of attribute `attr` for every source project.

        The values are returned in the order in which the sources are listed
        in the "sources" parameter. The element type depends on `attr`
        (e.g. booleans for "is_trained", timestamps for "modification_time").
        """
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel: bool = False) -> None:
        """Initialize all the source projects, forwarding `parallel` to each."""
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            project = self.project.registry.get_project(project_id)
            project.initialize(parallel)

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        """Ask every project in `sources` for suggestions for `texts`.

        Returns a dict mapping each source project id to the batch returned
        by that project's suggest() method.
        """
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Return the suggestions of the first source project for `texts`.

        Raises IndexError when the "sources" parameter lists no projects.
        """
        sources = annif.util.parse_sources(params["sources"])
        first_project_id = sources[0][0]
        # Only the first source's batch is returned, so only that project is
        # queried; asking the remaining sources would compute suggestions
        # that are immediately thrown away.
        return self._suggest_with_sources(texts, sources[:1])[first_project_id]
56
57
58
class LLMBackend(BaseLLMBackend):
    """Backend that lets an Azure OpenAI chat model rescore the suggestions
    produced by the first source project.

    Backend parameters read by this class: "sources", "endpoint" and "model".
    The API key is read from the AZURE_OPENAI_KEY environment variable.
    """

    name = "llm"

    # Instructions placed at the start of every prompt; _suggest_batch()
    # appends the document text and the candidate keywords to it.
    prompt_base = """
        I will give you text and some keywords to describe it. Your task is to
        score to the keywords with a value between 0.0 and 1.0, a perfect
        keyword should have score 1.0 and completely unrelated keyword score
        0.0. Output the same list of keywords and add its score separeted with
        comma, no other output or explanations.
    """

    @property
    def is_trained(self) -> bool:
        """True only when every source project reports being trained."""
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self) -> datetime | None:
        """Latest modification time among the source projects, or None when
        no source project reports one."""
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(
        self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0
    ) -> None:
        """Always raise NotSupportedException: this backend cannot be trained."""
        raise NotSupportedException("Training LLM backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Rescore the first source project's suggestions with the LLM.

        For every text, the labels of the base suggestions are sent to the
        model together with the text; the scores found in the model's answer
        replace the base scores. Base suggestions whose label is missing from
        the answer are dropped.

        Raises IndexError when the "sources" parameter lists no projects and
        KeyError when "sources", "endpoint" or "model" is missing.
        """
        sources = annif.util.parse_sources(params["sources"])
        endpoint = params["endpoint"]
        model = params["model"]

        base_project_id = sources[0][0]
        # Only the first source is rescored, so only that project is queried.
        base_suggestion_batch = self._suggest_with_sources(texts, sources[:1])[
            base_project_id
        ]

        batch_results = []
        for text, base_suggestions in zip(texts, base_suggestion_batch):
            # Materialize once: the suggestions are walked both here and in
            # _get_llm_suggestions().
            base_suggestions = list(base_suggestions)
            # NOTE(review): labels are always looked up in English ("en");
            # a vocabulary without English labels would presumably fail here
            # - confirm whether the language should be configurable.
            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            if not base_labels:
                # Nothing to rescore: the result would be empty whatever the
                # model answers, so skip the API call.
                batch_results.append([])
                continue

            prompt = (
                f"{self.prompt_base}\nHere is the text:\n{text}\n"
                "And here are the keywords:\n" + "\n".join(base_labels)
            )
            answer = self._call_llm(prompt, endpoint, model)
            llm_result = self._parse_llm_answer(answer)
            batch_results.append(
                self._get_llm_suggestions(llm_result, base_labels, base_suggestions)
            )
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _parse_llm_answer(
        self, answer: str | None
    ) -> tuple[list[str], list[float]]:
        """Parse the model's answer into parallel lists of labels and scores.

        Each non-blank line is expected to look like "<label>,<score>". The
        line is split at its LAST comma so that labels which themselves
        contain commas are still parsed, and surrounding whitespace is removed
        from the label so it can be matched exactly against the base labels.
        Lines without a comma or with a non-numeric score are logged and
        skipped instead of aborting the whole batch.

        Returns a tuple (labels, scores); both lists are empty when `answer`
        is empty or None.
        """
        labels: list[str] = []
        scores: list[float] = []
        if not answer:
            return labels, scores

        logger = logging.getLogger(__name__)
        for raw_line in answer.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            label, separator, raw_score = line.rpartition(",")
            try:
                if not separator:
                    raise ValueError("no comma in line")
                score = float(raw_score)
            except ValueError:
                logger.warning("Failed parsing line: %s", line)
                continue
            labels.append(label.strip())
            scores.append(score)
        return labels, scores

    def _get_llm_suggestions(self, llm_result, base_labels, base_suggestions):
        """Turn parsed LLM output into SubjectSuggestion objects.

        `llm_result` is the (labels, scores) tuple from _parse_llm_answer();
        `base_labels` and `base_suggestions` are parallel sequences. A label
        from the model is kept only when it is exactly equal to a base label,
        in which case it gets that base suggestion's subject id and the
        model's score. The output follows the order of the model's answer.
        """
        # Index subject ids by label once instead of rescanning the base
        # suggestions for every line of the answer. A list per label keeps
        # every match when several base suggestions share the same label.
        subject_ids_by_label: dict = {}
        for base_label, base_suggestion in zip(base_labels, base_suggestions):
            subject_ids_by_label.setdefault(base_label, []).append(
                base_suggestion.subject_id
            )

        suggestions = []
        for label, score in zip(*llm_result):
            for subject_id in subject_ids_by_label.get(label, ()):
                suggestions.append(
                    SubjectSuggestion(subject_id=subject_id, score=score)
                )
        return suggestions

    def _call_llm(self, prompt: str, endpoint: str, model: str) -> str | None:
        """Send `prompt` as a single user message to the Azure OpenAI chat
        completions API and return the content of the first choice.

        `endpoint` is the Azure endpoint and `model` is passed as the model
        argument of the request. The API key comes from the AZURE_OPENAI_KEY
        environment variable.
        """
        # NOTE(review): a new client is built on every call, i.e. once per
        # text in a batch - consider reusing one client per endpoint.
        client = AzureOpenAI(
            azure_endpoint=endpoint,
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

        messages = [
            {"role": "user", "content": prompt},
        ]
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0,
            max_tokens=1800,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
        )
        return completion.choices[0].message.content
167