Passed
Push: experiment-embeddings-backend (5ef702) by Juho, created 06:51

EmbeddingsBackend._train() (rated A)

Complexity
    Conditions: 3

Size
    Total Lines: 16
    Code Lines: 14

Duplication
    Lines: 16
    Ratio: 100 %

Importance
    Changes: 0

Metric   Value
cc       3
eloc     14
nop      4
dl       16
loc      16
rs       9.7
c        0
b        0
f        0
"""Vector space similarity backend for Annif based on text embeddings (TODO)."""

from __future__ import annotations

import os.path
import tempfile
from typing import TYPE_CHECKING, Any

import gensim.similarities
from gensim.matutils import Sparse2Corpus
from openai import AzureOpenAI

import annif.util
from annif.exception import NotInitializedException, NotSupportedException
from annif.suggestion import vector_to_suggestions

from . import backend
from .tfidf import SubjectBuffer  # assumed location: the tfidf backend this code mirrors

if TYPE_CHECKING:
    from collections.abc import Iterator

    from scipy.sparse._csr import csr_matrix

    from annif.corpus.document import DocumentCorpus


class Vectorizer:
    """Compute text embeddings with an Azure OpenAI deployment."""

    def __init__(self, endpoint, model):
        self.model = model
        self.client = AzureOpenAI(  # TODO Try AsyncAzureOpenAI
            azure_endpoint=endpoint,
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2024-02-15-preview",
        )

    def vectorize(self, text):
        # Embed a single text and return its embedding vector
        response = self.client.embeddings.create(
            input=text,
            model=self.model,
            # dimensions=dimensions,  # TODO Try with reduced dimensions
        )
        return response.data[0].embedding
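
# Usage sketch (illustrative): how the Vectorizer above could be exercised
# on its own. The endpoint and deployment name are hypothetical, and the
# AZURE_OPENAI_KEY environment variable must be set.
#
#     vectorizer = Vectorizer(
#         "https://example.openai.azure.com", "text-embedding-3-small"
#     )
#     embedding = vectorizer.vectorize("machine learning for libraries")
#     len(embedding)  # dimensionality of the returned embedding vector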


# Review tool note: this class is flagged as duplicated code; it appears to
# be copied from the tfidf backend.
class EmbeddingsBackend(backend.AnnifBackend):
    """TODO: vector space similarity based backend for Annif"""

    name = "embeddings"

    # defaults for uninitialized instances
    _index = None

    # Filename kept from the tfidf backend this code is based on; uncommented
    # because _initialize_index() and _create_index() depend on it.
    INDEX_FILE = "tfidf-index"

    def _generate_subjects_from_documents(
        self, corpus: DocumentCorpus
    ) -> Iterator[str]:
        # Buffer the tokenized text of each document under every subject it
        # is assigned to, then yield one combined text per subject.
        with tempfile.TemporaryDirectory() as tempdir:
            subject_buffer = {}
            for subject_id in range(len(self.project.subjects)):
                subject_buffer[subject_id] = SubjectBuffer(tempdir, subject_id)

            for doc in corpus.documents:
                tokens = self.project.analyzer.tokenize_words(doc.text)
                for subject_id in doc.subject_set:
                    subject_buffer[subject_id].write(" ".join(tokens))

            for sid in range(len(self.project.subjects)):
                yield subject_buffer[sid].read()

    def _initialize_index(self) -> None:
        # Lazily load the similarity index from the data directory
        if self._index is None:
            path = os.path.join(self.datadir, self.INDEX_FILE)
            self.debug("loading similarity index from {}".format(path))
            if os.path.exists(path):
                self._index = gensim.similarities.SparseMatrixSimilarity.load(path)
            else:
                raise NotInitializedException(
                    "similarity index {} not found".format(path),
                    backend_id=self.backend_id,
                )

    def initialize(self, parallel: bool = False) -> None:
        # The original called Vectorizer() with no arguments; the endpoint
        # and model are assumed here to come from the backend parameters.
        self.vectorizer = Vectorizer(self.params["endpoint"], self.params["model"])
        self._initialize_index()

    def _create_index(self, veccorpus: csr_matrix) -> None:
        self.info("creating similarity index")
        gscorpus = Sparse2Corpus(veccorpus, documents_columns=False)
        # TODO vocabulary_ is a TfidfVectorizer attribute; the embeddings
        # Vectorizer above does not provide it, so num_features should be
        # the embedding dimensionality instead.
        self._index = gensim.similarities.SparseMatrixSimilarity(
            gscorpus, num_features=len(self.vectorizer.vocabulary_)
        )
        annif.util.atomic_save(self._index, self.datadir, self.INDEX_FILE)

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        if corpus == "cached":
            raise NotSupportedException(
                "Training embeddings project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException(
                "Cannot train embeddings project with no documents"
            )
        self.info("transforming subject corpus")
        subjects = self._generate_subjects_from_documents(corpus)
        # TODO create_vectorizer() is not defined on this class; the subject
        # texts presumably still need to be embedded with self.vectorizer.
        veccorpus = self.create_vectorizer(subjects)
        self._create_index(veccorpus)

    def _suggest(self, text: str, params: dict[str, Any]) -> Iterator:
        self.debug(
            'Suggesting subjects for text "{}..." (len={})'.format(text[:20], len(text))
        )
        tokens = self.project.analyzer.tokenize_words(text)
        # The original called self.vectorizer.transform(), which the
        # Vectorizer above does not define; vectorize() is assumed instead.
        vector = self.vectorizer.vectorize(" ".join(tokens))
        return vector_to_suggestions(self._index[vector], int(params["limit"]))
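
For context, a sketch of how a backend like this could be wired into an Annif project configuration. The backend id "embeddings" comes from the name attribute above; the endpoint and model parameters are assumptions matching the initialize() fix, and the analyzer and vocab values are placeholders:

    [embeddings-en]
    name=Embeddings English
    language=en
    backend=embeddings
    analyzer=snowball(english)
    vocab=yso
    endpoint=https://example.openai.azure.com
    model=text-embedding-3-small

With such a project defined, training and suggesting would follow the usual Annif CLI workflow, e.g. "annif train embeddings-en /path/to/corpus" and then "annif suggest embeddings-en" with the document text on standard input.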