Passed · Push resolve-keras-userwarning (34af57) · by Juho · created 03:31

annif.backend.tfidf.TFIDFBackend._suggest_batch() (rated A)

Complexity:   Conditions 1
Size:         Total Lines 16, Code Lines 9
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 0

Raw metrics:

Metric  Value
cc      1
eloc    9
nop     3
dl      0
loc     16
rs      9.95
c       0
b       0
f       0
"""Backend that returns most similar subjects based on similarity in sparse
TF-IDF normalized bag-of-words vector space"""

from __future__ import annotations

import os.path
import tempfile
from typing import TYPE_CHECKING, Any

from scipy.sparse import csr_array, load_npz, save_npz
from sklearn.preprocessing import normalize

import annif.util
from annif.exception import (
    NotInitializedException,
    NotSupportedException,
    OperationFailedException,
)
from annif.suggestion import SuggestionBatch

from . import backend, mixins

if TYPE_CHECKING:
    from collections.abc import Iterator

    from annif.corpus import Document, DocumentCorpus


class SubjectBuffer:
    """A file-backed buffer to store and retrieve subject text."""

    BUFFER_SIZE = 100

    def __init__(self, tempdir: str, subject_id: int) -> None:
        filename = "{:08d}.txt".format(subject_id)
        self._path = os.path.join(tempdir, filename)
        self._buffer = []
        self._created = False

    def flush(self) -> None:
        if self._created:
            mode = "a"
        else:
            mode = "w"

        with open(self._path, mode, encoding="utf-8") as subjfile:
            for text in self._buffer:
                print(text, file=subjfile)

        self._buffer = []
        self._created = True

    def write(self, text: str) -> None:
        self._buffer.append(text)
        if len(self._buffer) >= self.BUFFER_SIZE:
            self.flush()

    def read(self) -> str:
        if not self._created:
            # file was never created - we can simply return the buffer content
            return "\n".join(self._buffer)
        else:
            with open(self._path, "r", encoding="utf-8") as subjfile:
                return subjfile.read() + "\n" + "\n".join(self._buffer)
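
# Hypothetical usage sketch (not part of the module): SubjectBuffer keeps
# writes in memory and spills them to its backing file once BUFFER_SIZE
# lines accumulate, so read() may join on-disk content with an unflushed
# tail. With fewer than BUFFER_SIZE writes, nothing touches the disk:
#
#     with tempfile.TemporaryDirectory() as tmpdir:
#         buf = SubjectBuffer(tmpdir, subject_id=0)
#         buf.write("tokens of first document")
#         buf.write("tokens of second document")
#         assert buf.read() == (
#             "tokens of first document\ntokens of second document"
#         )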


class TFIDFBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend):
    """TF-IDF vector space similarity based backend for Annif"""

    name = "tfidf"

    # defaults for uninitialized instances
    _tfidf_matrix = None

    MATRIX_FILE = "tfidf-matrix.npz"
    OLD_INDEX_FILE = "tfidf-index"

    def _generate_subjects_from_documents(
        self, corpus: DocumentCorpus
    ) -> Iterator[str]:
        with tempfile.TemporaryDirectory() as tempdir:
            subject_buffer = {}
            for subject_id in range(len(self.project.subjects)):
                subject_buffer[subject_id] = SubjectBuffer(tempdir, subject_id)

            for doc in corpus.documents:
                tokens = self.project.analyzer.tokenize_words(doc.text)
                for subject_id in doc.subject_set:
                    subject_buffer[subject_id].write(" ".join(tokens))

            for sid in range(len(self.project.subjects)):
                yield subject_buffer[sid].read()
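
    # Illustration (hypothetical data, not from the source): given two
    # training documents
    #   doc1 "cats purr"  with subjects {0}
    #   doc2 "dogs bark"  with subjects {0, 1}
    # the generator yields one concatenated text per subject id:
    #   subject 0: "cats purr\ndogs bark"
    #   subject 1: "dogs bark"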

    def _initialize_index(self) -> None:
        if self._tfidf_matrix is None:
            path = os.path.join(self.datadir, self.MATRIX_FILE)
            self.debug("loading tf-idf matrix from {}".format(path))
            if os.path.exists(path):
                self._tfidf_matrix = load_npz(path)
            elif os.path.exists(os.path.join(self.datadir, self.OLD_INDEX_FILE)):
                raise OperationFailedException(
                    "TFIDF models trained on Annif versions older than 1.4 cannot be "
                    "loaded. Please retrain your project."
                )
            else:
                raise NotInitializedException(
                    "tf-idf matrix {} not found".format(path),
                    backend_id=self.backend_id,
                )

    def initialize(self, parallel: bool = False) -> None:
        self.initialize_vectorizer()
        self._initialize_index()

    def _train(
        self,
        corpus: DocumentCorpus,
        params: dict[str, Any],
        jobs: int = 0,
    ) -> None:
        if corpus == "cached":
            raise NotSupportedException(
                "Training tfidf project from cached data not supported."
            )
        if corpus.is_empty():
            raise NotSupportedException("Cannot train tfidf project with no documents")
        self.info("transforming subject corpus")
        subjects = self._generate_subjects_from_documents(corpus)
        # Note: Intentionally don't pass a tokenizer to the vectorizer here.
        # Instead the tokenization is done inside _generate_subjects_from_documents
        # and in _suggest_batch. This way, the same train document doesn't have to be
        # tokenized many times during training if it has many subjects.
        self._tfidf_matrix = normalize(self.create_vectorizer(subjects))
        self.info("saving tf-idf matrix")
        annif.util.atomic_save(
            self._tfidf_matrix,
            self.datadir,
            self.MATRIX_FILE,
            lambda obj, filename: save_npz(filename, obj),
        )

    def _suggest_batch(
        self, documents: list[Document], params: dict[str, Any]
    ) -> SuggestionBatch:
        query_vector = normalize(
            self.vectorizer.transform(
                [
                    " ".join(self.project.analyzer.tokenize_words(doc.text))
                    for doc in documents
                ]
            )
        )

        # Compute cosine similarity between query and indexed corpus
        similarities = query_vector @ self._tfidf_matrix.T

        return SuggestionBatch(csr_array(similarities)).filter(int(params["limit"]))
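
For reference, a minimal standalone sketch of the similarity computation used in _suggest_batch above (toy data and variable names are hypothetical; assumes scikit-learn and scipy, as imported in the module). Because the query vectors and the stored matrix rows are both L2-normalized, the sparse matrix product directly yields cosine similarities:

    from scipy.sparse import csr_array
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.preprocessing import normalize

    subjects = ["cats and dogs", "dogs and wolves", "stars and planets"]
    vectorizer = TfidfVectorizer()
    # one L2-normalized TF-IDF row per subject
    tfidf_matrix = normalize(vectorizer.fit_transform(subjects))

    queries = normalize(vectorizer.transform(["wolves and dogs howl"]))
    # dot products of unit vectors equal cosine similarities
    similarities = csr_array(queries @ tfidf_matrix.T)
    print(similarities.toarray())  # highest score for "dogs and wolves"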