Passed
Push — experiment-embeddings-backend (349504...3bcf99)
by Juho, created 03:03

EmbeddingsBackend._combine_search_results(): grade A

Complexity: Conditions 2
Size: Total Lines 6, Code Lines 6
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0
Metric   Value
cc       2
eloc     6
nop      2
dl       0
loc      6
rs       10
c        0
b        0
f        0
"""Embeddings based similarity backend for Annif"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any

import numpy as np
import tiktoken
from openai import AzureOpenAI  # TODO Try using a Hugging Face client instead
from qdrant_client import QdrantClient, models

from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import vector_to_suggestions

from . import backend

if TYPE_CHECKING:
    from collections.abc import Iterator

    from annif.corpus.document import DocumentCorpus


class Vectorizer:
    """Compute text embeddings using an Azure OpenAI deployment."""

    def __init__(self, endpoint, model):
        self.model = model
        self.client = AzureOpenAI(  # TODO Try AsyncAzureOpenAI
            azure_endpoint=endpoint,
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def transform(self, text):
        """Return the embedding vector for the given text."""
        response = self.client.embeddings.create(
            input=text,
            model=self.model,
            # dimensions=dimensions,  # TODO Try with reduced dimensions
        )
        return response.data[0].embedding
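

# A quick way to sanity-check the Vectorizer on its own (hypothetical
# endpoint and deployment name; assumes AZURE_OPENAI_KEY is set in the
# environment):
#
#     vectorizer = Vectorizer(
#         "https://example.openai.azure.com/", "text-embedding-3-large"
#     )
#     len(vectorizer.transform("knowledge organization"))  # -> 3072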


class EmbeddingsBackend(backend.AnnifBackend):
    """Vector space similarity based backend for Annif using text embeddings"""

    name = "embeddings"
    is_trained = True

    # defaults for uninitialized instances
    _index = None

    DB_FILE = "qdrant-db"
    COLLECTION_NAME = "index-collection"
    BASE_MODEL = "text-embedding-3-large"
    VECTOR_DIMENSIONS = 3072  # for text-embedding-3-large
    MAX_TOKENS = 8192  # for text-embedding-3-large

    encoding = tiktoken.encoding_for_model(BASE_MODEL)

    def _initialize_index(self) -> None:
        if self._index is None:
            path = os.path.join(self.datadir, self.DB_FILE)
            self.debug("loading similarity index from {}".format(path))
            if os.path.exists(path):
                self.qdclient = QdrantClient(path=path)
            else:
                raise NotInitializedException(
                    "similarity index {} not found".format(path),
                    backend_id=self.backend_id,
                )

    def initialize(self) -> None:
        self.vectorizer = Vectorizer(self.params["endpoint"], self.params["model"])
        self._initialize_index()

    def _create_index(self, corpus) -> None:
        self.vectorizer = Vectorizer(self.params["endpoint"], self.params["model"])
        self.info("creating similarity index")
        path = os.path.join(self.datadir, self.DB_FILE)

        self.qdclient = QdrantClient(path=path)
        self.qdclient.recreate_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=models.VectorParams(
                size=self.VECTOR_DIMENSIONS,
                distance=models.Distance.COSINE,
            ),
        )

        # pair each document's subject set with the embedding of its text
        veccorpus = (
            (
                doc.subject_set,
                self.vectorizer.transform(self._truncate_text(doc.text)),
            )
            for doc in corpus.documents
        )

        subject_sets, vectors = zip(*veccorpus)
        payloads = [{"subjects": list(ss)} for ss in subject_sets]
        ids = list(range(len(vectors)))
        self.qdclient.upsert(
            collection_name=self.COLLECTION_NAME,
            points=models.Batch(
                ids=ids,
                vectors=vectors,
                payloads=payloads,
            ),
        )
        self.debug(
            str(self.qdclient.get_collection(collection_name=self.COLLECTION_NAME))
        )

    def _truncate_text(self, text):
        """Truncate the text so it contains at most MAX_TOKENS tokens
        according to the OpenAI tokenizer"""
        tokens = self.encoding.encode(text)
        return self.encoding.decode(tokens[: self.MAX_TOKENS])

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        if corpus == "cached":
            raise NotSupportedException(
                "Training embeddings project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "Cannot train embeddings project with no documents"
            )
        self.info("transforming subject corpus")
        self._create_index(corpus)

    def _suggest(self, text: str, params: dict[str, Any]) -> Iterator:
        self.debug(
            'Suggesting subjects for text "{}..." (len={})'.format(text[:20], len(text))
        )
        truncated_text = self._truncate_text(text)
        vector = self.vectorizer.transform(truncated_text)
        info = self.qdclient.get_collection(collection_name=self.COLLECTION_NAME)
        self.debug(f"Collection info: {info}")
        results = self._search(vector, params)
        return self._prediction_to_result(results, params)

    def _search(self, vector, params):
        result = self.qdclient.search(
            collection_name=self.COLLECTION_NAME,
            query_vector=vector,
            # score_threshold=1.0,  # TODO parameterize this
            limit=int(params["limit"]),
            # search_params=models.SearchParams(hnsw_ef=128, exact=False),  # TODO try
        )
        return [(sp.payload["subjects"], sp.score) for sp in result]

    def _combine_search_results(self, results):
        # flatten the hits into (subject_id, score) pairs so that each subject
        # of a matched document carries that document's similarity score
        combined = []
        for subject_ids, score in results:
            combined.extend([(sid, score) for sid in subject_ids])
        return combined

    # adapted from backend/mllm.py
    def _prediction_to_result(
        self,
        results,
        params,
    ) -> Iterator:
        # accumulate the scores of all matched documents into a dense vector
        # over the whole vocabulary and convert it into suggestions
        vector = np.zeros(len(self.project.subjects), dtype=np.float32)
        for subject_ids, score in results:
            for sid in subject_ids:
                vector[sid] += score
        return vector_to_suggestions(vector, int(params["limit"]))
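
To make the scoring step concrete, here is a small self-contained sketch of the accumulation that _prediction_to_result performs: each search hit contributes its similarity score to every subject of the matched document, so subjects that recur across several similar hits accumulate higher scores. The hit data below is made up for illustration.

import numpy as np

# three hypothetical search hits as (subject_ids, score) pairs,
# over a toy vocabulary of six subjects
results = [([0, 2], 0.9), ([2, 4], 0.7), ([1], 0.5)]

vector = np.zeros(6, dtype=np.float32)
for subject_ids, score in results:
    for sid in subject_ids:
        vector[sid] += score

print(vector)  # [0.9 0.5 1.6 0.  0.7 0. ]; subject 2 appeared in two hits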