Passed
Push — experiment-embeddings-backend-... (bc5f53), created by Juho at 06:59

Function: EmbeddingsBackend._truncate_text()
Grade: A

Complexity
  Conditions: 1

Size
  Total Lines: 5
  Code Lines: 3

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric                          Value
cc (cyclomatic complexity)      1
eloc (effective lines of code)  3
nop (number of parameters)      2
dl (duplicated lines)           0
loc (lines of code)             5
rs                              10
c                               0
b                               0
f                               0
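The report does not name the tool that produced these numbers, but as a sketch, a cyclomatic complexity of 1 for this straight-line function can be reproduced locally with radon's cc_visit (the choice of radon is an assumption):

from radon.complexity import cc_visit

# The function under analysis, as a standalone snippet
source = """
def _truncate_text(self, text):
    tokens = self.encoding.encode(text)
    return self.encoding.decode(tokens[: self.MAX_TOKENS])
"""

for block in cc_visit(source):
    # A function with no branches has cyclomatic complexity 1
    print(block.name, block.complexity)  # -> _truncate_text 1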
"""Annif backend suggesting subjects via semantic similarity of OpenAI text embeddings"""

from __future__ import annotations

import os.path
from typing import TYPE_CHECKING, Any

import numpy as np
import tiktoken
from openai import AzureOpenAI  # Try using huggingface client

# import annif.util
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import vector_to_suggestions

from . import backend

if TYPE_CHECKING:
    from collections.abc import Iterator

    from annif.corpus.document import DocumentCorpus


class Vectorizer:
    def __init__(self, endpoint, model):
        self.model = model
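        # Expects the AZURE_OPENAI_KEY environment variable to be set;
        # os.getenv returns None otherwise and API calls will fail.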
        self.client = AzureOpenAI(  # TODO Try AsyncAzureOpenAI(
            azure_endpoint=endpoint,
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def transform(self, text):
        response = self.client.embeddings.create(
            input=text,
            model=self.model,
            # dimensions=dimensions,  # TODO Try with reduced dimensions
        )
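        # With a single string input, data holds exactly one embedding,
        # returned as a plain list of floats.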
        return response.data[0].embedding


class EmbeddingsBackend(backend.AnnifBackend):
    """Semantic vector space similarity based backend for Annif"""

    name = "embeddings"
    _index = None

    INDEX_FILE = "embeddings-index.npy"
    BASE_MODEL = "text-embedding-3-large"
    VECTOR_DIMENSIONS = 3072  # For text-embedding-3-large
    MAX_TOKENS = 8192  # For text-embedding-3-large

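    # For text-embedding-3-large this resolves to the cl100k_base encoding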
    encoding = tiktoken.encoding_for_model(BASE_MODEL)

    def _initialize_index(self) -> None:
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            self.debug("loading similarity index from {}".format(path))
            if os.path.exists(path):
                self._index = np.load(path, allow_pickle=True)
            else:
                raise NotInitializedException(
                    "similarity index {} not found".format(path),
                    backend_id=self.backend_id,
                )

    def initialize(
        self,
        parallel: bool = False,
    ) -> None:
        self.vectorizer = Vectorizer(self.params["endpoint"], self.params["model"])
        self._initialize_index()

    def _create_index(self, corpus) -> None:
        self.vectorizer = Vectorizer(self.params["endpoint"], self.params["model"])
        self.info("creating similarity index")
        path = os.path.join(self.datadir, self.INDEX_FILE)

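        # Accumulate each document's embedding into every subject assigned to
        # it, so each row becomes an (unnormalized) subject centroid vector.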
        subject_vectors = np.zeros((len(self.project.subjects), self.VECTOR_DIMENSIONS))
        for doc in corpus.documents:
            vec = self.vectorizer.transform(self._truncate_text(doc.text))
            for sid in doc.subject_set:
                subject_vectors[sid, :] = subject_vectors[sid, :] + vec

        row_norms = np.linalg.norm(subject_vectors, axis=1, keepdims=True)

        # Avoid division by zero: Only normalize non-zero rows
        self._index = np.where(row_norms == 0, 0, subject_vectors / row_norms)
        np.save(path, self._index, allow_pickle=True)

    def _truncate_text(self, text):
        """truncate text so it contains at most MAX_TOKENS according to the OpenAI
        tokenizer"""
        tokens = self.encoding.encode(text)
        return self.encoding.decode(tokens[: self.MAX_TOKENS])

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        if corpus == "cached":
            raise NotSupportedException(
                "Training embeddings project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "Cannot train embeddings project with no documents"
            )
        self.info("transforming subject corpus")
        self._create_index(corpus)

    def _suggest(self, text: str, params: dict[str, Any]) -> Iterator:
        self.debug(
            'Suggesting subjects for text "{}..." (len={})'.format(text[:20], len(text))
        )
        truncated_text = self._truncate_text(text)
        vector = self.vectorizer.transform(truncated_text)

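        # Rows of the index are L2-normalized and OpenAI embeddings are
        # returned with unit length, so this dot product is the cosine
        # similarity between the query and each subject vector.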
        cosine_similarity = np.dot(self._index, np.array(vector))
        return vector_to_suggestions(cosine_similarity, int(params["limit"]))
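For context, a minimal sketch of how a project using this backend might be wired up in Annif's projects.cfg. The section name, project name, and endpoint URL are placeholder assumptions; endpoint, model and limit are the parameters the code above actually reads, and the API key is picked up from the AZURE_OPENAI_KEY environment variable, which must be exported before training or suggesting.

[embeddings-en]
name=Embeddings English
language=en
backend=embeddings
vocab=yso
endpoint=https://example-resource.openai.azure.com/
model=text-embedding-3-large
limit=100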