Passed
Push — experiment-hf-transformers-zer... ( 393cd7 )
by Juho
04:26
created

RescorerBackend.initialize()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 9
nop 2
dl 0
loc 11
rs 9.95
c 0
b 0
f 0
1
"""TODO"""
2
3
from __future__ import annotations
4
5
import json
6
import os
7
from typing import TYPE_CHECKING, Any
8
9
# import tiktoken
10
from transformers import pipeline
11
12
import annif.eval
13
import annif.parallel
14
import annif.util
15
from annif.exception import NotSupportedException
16
from annif.suggestion import SubjectSuggestion, SuggestionBatch
17
18
from . import backend
19
20
# from openai import AsyncAzureOpenAI
21
22
23
if TYPE_CHECKING:
24
    from datetime import datetime
25
26
    from annif.corpus.document import DocumentCorpus
27
28
29
class RescorerBackend(backend.AnnifBackend):
    """Backend that rescores the suggestions of a source project using a
    zero-shot text classification model.

    The suggestions of the first configured source project are re-weighted
    with the scores that the zero-shot classifier assigns to the (English)
    labels of the suggested subjects.
    """

    name = "rescorer"

    def _get_sources_attribute(self, attr: str) -> list[bool | None]:
        """Return the value of *attr* for each configured source project."""
        params = self._get_backend_params(None)
        sources = annif.util.parse_sources(params["sources"])
        return [
            getattr(self.project.registry.get_project(project_id), attr)
            for project_id, _ in sources
        ]

    def initialize(self, parallel: bool = False) -> None:
        """Initialize all source projects and load the zero-shot
        classification pipeline.

        :param parallel: passed through to the source projects' initialize()
        """
        params = self._get_backend_params(None)
        for project_id, _ in annif.util.parse_sources(params["sources"]):
            self.project.registry.get_project(project_id).initialize(parallel)

        # NOTE(review): the pipeline is re-created on every initialize() call;
        # consider guarding with an "already initialized" check if initialize
        # can be invoked repeatedly.
        self.classifier = pipeline(
            "zero-shot-classification",
            model=params.get("model"),
            from_pt=True,
            multi_label=True,
        )

    def _suggest_with_sources(
        self, texts: list[str], sources: list[tuple[str, float]]
    ) -> dict[str, SuggestionBatch]:
        """Get a suggestion batch for *texts* from every source project,
        keyed by project id. The source weights are ignored here."""
        return {
            project_id: self.project.registry.get_project(project_id).suggest(texts)
            for project_id, _ in sources
        }

    @property
    def is_trained(self) -> bool:
        """True iff every source project is trained."""
        sources_trained = self._get_sources_attribute("is_trained")
        return all(sources_trained)

    @property
    def modification_time(self) -> datetime | None:
        """Most recent modification time of the source projects, or None if
        none of them report one."""
        mtimes = self._get_sources_attribute("modification_time")
        return max(filter(None, mtimes), default=None)

    def _train(self, corpus: DocumentCorpus, params: dict[str, Any], jobs: int = 0):
        """Training is not supported; this backend only rescores the results
        of already-trained source projects."""
        raise NotSupportedException("Training rescorer backend is not possible.")

    def _suggest_batch(
        self, texts: list[str], params: dict[str, Any]
    ) -> SuggestionBatch:
        """Suggest subjects for *texts* by rescoring the suggestions of the
        first source project with the zero-shot classifier.

        :param texts: documents to suggest subjects for
        :param params: backend parameters; "sources" and "new_scores_weight"
            are used here
        """
        sources = annif.util.parse_sources(params["sources"])
        new_scores_weight = float(params["new_scores_weight"])

        # NOTE(review): every source project is queried, but only the first
        # one is used as the base for rescoring — confirm whether the other
        # sources are actually needed, or query only sources[0].
        base_suggestion_batch = self._suggest_with_sources(texts, sources)[
            sources[0][0]
        ]

        batch_results = []
        for text, base_suggestions in zip(texts, base_suggestion_batch):
            # assumes every subject has an English label — TODO confirm
            base_labels = [
                self.project.subjects[s.subject_id].labels["en"]
                for s in base_suggestions
            ]
            result = self.classifier(text, base_labels)
            batch_results.append(
                self._rescore_suggestions(
                    result, base_labels, base_suggestions, new_scores_weight
                )
            )
        return SuggestionBatch.from_sequence(batch_results, self.project.subjects)

    def _rescore_suggestions(
        self,
        result: dict[str, Any],
        base_labels: list[str],
        base_suggestions,
        new_scores_weight: float,
    ):
        """Combine each base suggestion score with the classifier score of
        its label as a weighted mean and return the new suggestions.

        :param result: classifier output with parallel "labels" and "scores"
            lists
        :param base_labels: label of each base suggestion, in order
        :param base_suggestions: the source project's suggestions
        :param new_scores_weight: weight of the classifier score in [0, 1]
        """
        base_scores_weight = 1.0 - new_scores_weight  # loop-invariant
        suggestions = []
        for blabel, bsuggestion in zip(base_labels, base_suggestions):
            try:
                ind = result["labels"].index(blabel)
                score = result["scores"][ind]
            except ValueError:
                # classifier did not return this label; fall back to using
                # only the base suggestion score
                print(f"Base label {blabel} not found in new labels")
                score = bsuggestion.score

            # weighted mean of the classifier and base scores; the
            # denominator keeps the formula valid for other weightings
            mean_score = (
                base_scores_weight * bsuggestion.score
                + new_scores_weight * score
            ) / (base_scores_weight + new_scores_weight)
            suggestions.append(
                SubjectSuggestion(
                    subject_id=bsuggestion.subject_id, score=mean_score
                )
            )
        return suggestions