Passed
Push — experiment-embeddings-backend ( 5ef702...349504 )
by Juho
04:47
created

annif.backend.embeddings.Vectorizer.transform()   A

Complexity

Conditions 1

Size

Total Lines 7
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 2
dl 0
loc 7
rs 10
c 0
b 0
f 0
1
"""TODO"""
2
3
from __future__ import annotations
4
5
import os.path
6
import tempfile
7
from typing import TYPE_CHECKING, Any
8
9
from openai import AzureOpenAI
10
from qdrant_client import QdrantClient
11
from qdrant_client.models import Distance, VectorParams, Batch
12
13
import annif.util
14
from annif.exception import NotInitializedException, NotSupportedException
15
from annif.suggestion import vector_to_suggestions
16
17
from . import backend
18
19
if TYPE_CHECKING:
20
    from collections.abc import Iterator
21
22
    from scipy.sparse._csr import csr_matrix
23
24
    from annif.corpus.document import DocumentCorpus
25
26
27
class Vectorizer:
    """Compute text embeddings through an Azure OpenAI embeddings endpoint.

    The API key is taken from the ``AZURE_OPENAI_KEY`` environment variable
    at construction time.
    """

    def __init__(self, endpoint, model):
        """Create the Azure OpenAI client for *endpoint* and remember *model*."""
        self.model = model
        api_key = os.getenv("AZURE_OPENAI_KEY")
        # TODO Try AsyncAzureOpenAI(
        self.client = AzureOpenAI(
            azure_endpoint=endpoint,
            api_key=api_key,
            api_version="2024-02-15-preview",
        )

    def transform(self, text):
        """Return the embedding of *text*.

        *text* is sent unchanged as the ``input`` of the embeddings request,
        and only the first embedding contained in the response is returned.
        """
        # TODO Try with reduced dimensions (the ``dimensions`` request argument)
        embeddings_api = self.client.embeddings
        reply = embeddings_api.create(input=text, model=self.model)
        first_item = reply.data[0]
        return first_item.embedding
43
44
45
class EmbeddingsBackend(backend.AnnifBackend):
    """Vector space similarity based backend for Annif.

    Training embeds every document of the training corpus with an Azure
    OpenAI embeddings model (see ``Vectorizer``) and stores the vectors,
    together with each document's subject ids, in a local Qdrant collection.
    Suggesting embeds the input text, retrieves the most similar training
    documents from Qdrant and scores subjects by the similarity of the
    retrieved documents that carry them.
    """

    name = "embeddings"
    is_trained = True

    # defaults for uninitialized instances
    vectorizer = None
    qdclient = None

    DB_FILE = "qdrant-db"
    VECTOR_DIMENSIONS = 3072  # For text-embedding-3-large
    COLLECTION_NAME = "index-collection"

    def _create_vectorizer(self) -> Vectorizer:
        """Build a Vectorizer from the ``endpoint`` and ``model`` parameters."""
        return Vectorizer(self.params["endpoint"], self.params["model"])

    def _db_path(self) -> str:
        """Return the location of the Qdrant database inside the data dir."""
        return os.path.join(self.datadir, self.DB_FILE)

    def _initialize_index(self) -> None:
        """Open the Qdrant database written by training (only once).

        Raises:
            NotInitializedException: if the database directory is missing.
        """
        if self.qdclient is not None:
            return
        path = self._db_path()
        self.debug("loading similarity index from {}".format(path))
        if not os.path.exists(path):
            raise NotInitializedException(
                "similarity index {} not found".format(path),
                backend_id=self.backend_id,
            )
        # Open the same datadir location that _create_index writes to; the
        # bare DB_FILE name would be resolved against the working directory.
        self.qdclient = QdrantClient(path=path)

    def initialize(self, parallel: bool = False) -> None:
        """Create the embeddings client and open the similarity index.

        Repeated calls are cheap: both resources are created only once.
        ``parallel`` is accepted but unused. NOTE(review): it was added to
        match the ``initialize(parallel)`` signature used by other Annif
        backends; confirm against AnnifBackend.
        """
        if self.vectorizer is None:
            self.vectorizer = self._create_vectorizer()
        self._initialize_index()

    def _create_index(self, corpus: DocumentCorpus) -> None:
        """Embed every document of *corpus* and store it in Qdrant.

        Each point's payload holds the document's subject ids under the
        ``"subjects"`` key; point ids are consecutive integers from 0.
        """
        self.vectorizer = self._create_vectorizer()
        self.info("creating similarity index")
        if self.qdclient is None:
            self.qdclient = QdrantClient(path=self._db_path())
        # Start from an empty collection so that retraining works.
        # NOTE(review): create_collection is expected to refuse an existing
        # collection; confirm for the qdrant_client version in use.
        self.qdclient.delete_collection(collection_name=self.COLLECTION_NAME)
        self.qdclient.create_collection(
            collection_name=self.COLLECTION_NAME,
            vectors_config=VectorParams(
                size=self.VECTOR_DIMENSIONS, distance=Distance.DOT
            ),
        )

        subject_ids = []
        vectors = []
        for doc in corpus.documents:
            subject_ids.append(list(doc.subject_set))
            vectors.append(self.vectorizer.transform(doc.text))

        self.qdclient.upsert(
            collection_name=self.COLLECTION_NAME,
            points=Batch(
                ids=list(range(len(vectors))),
                vectors=vectors,
                payloads=[{"subjects": sids} for sids in subject_ids],
            ),
        )
        self.debug(
            "collection info: {}".format(
                self.qdclient.get_collection(collection_name=self.COLLECTION_NAME)
            )
        )

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        """Build the similarity index from *corpus*.

        Raises:
            NotSupportedException: for the ``"cached"`` pseudo-corpus or an
                empty corpus.
        """
        if corpus == "cached":
            raise NotSupportedException(
                "Training embeddings project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "Cannot train embeddings project with no documents"
            )
        self.info("transforming subject corpus")
        self._create_index(corpus)

    def _suggest(self, text: str, params: dict[str, Any]) -> Iterator:
        """Suggest subjects for *text* from its nearest training documents.

        The ``limit`` parameter sets both the number of neighbouring
        documents fetched from Qdrant and the number of suggestions returned.
        A subject's score is the summed similarity of the neighbours carrying
        it divided by the number of neighbours, clipped to [0, 1].
        """
        import numpy as np  # local import: numpy is not imported at module level

        self.debug(
            'Suggesting subjects for text "{}..." (len={})'.format(text[:20], len(text))
        )
        self.initialize()
        limit = int(params["limit"])
        # Embed the text as-is; " ".join(text) on a str would insert a space
        # between every character.
        query_vector = self.vectorizer.transform(text)
        hits = self.qdclient.search(
            collection_name=self.COLLECTION_NAME,
            query_vector=query_vector,
            limit=limit,
        )
        scores = np.zeros(len(self.project.subjects), dtype=np.float32)
        for hit in hits:
            for subject_id in hit.payload["subjects"]:
                scores[subject_id] += hit.score
        if hits:
            scores /= len(hits)
        np.clip(scores, 0.0, 1.0, out=scores)
        return vector_to_suggestions(scores, limit)
141