| 1 |  |  | """Backend that returns most similar subjects based on similarity in sparse | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | TF-IDF normalized bag-of-words vector space""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from __future__ import annotations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import os.path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | import tempfile | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | from typing import TYPE_CHECKING, Any | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from scipy.sparse import csr_array, load_npz, save_npz | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from sklearn.preprocessing import normalize | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | import annif.util | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from annif.exception import ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     NotInitializedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |     NotSupportedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |     OperationFailedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  | from annif.suggestion import SuggestionBatch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  | from . import backend, mixins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  | if TYPE_CHECKING: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |     from collections.abc import Iterator | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |     from annif.corpus import Document, DocumentCorpus | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                class SubjectBuffer:
                    """A file-backed buffer to store and retrieve subject text.

                    Texts passed to ``write`` are kept in memory and appended to the file
                    ``<tempdir>/<subject_id as 8 zero-padded digits>.txt`` once
                    ``BUFFER_SIZE`` of them have accumulated, or when ``flush`` is called.
                    ``read`` returns every text written so far, in write order, joined
                    with newlines, independent of how many of them were flushed.
                    """

                    # number of in-memory texts that triggers an automatic flush
                    BUFFER_SIZE = 100

                    def __init__(self, tempdir: str, subject_id: int) -> None:
                        """Set up an empty buffer; the backing file is created on first flush."""
                        filename = "{:08d}.txt".format(subject_id)
                        self._path = os.path.join(tempdir, filename)
                        self._buffer: list[str] = []
                        self._created = False

                    def flush(self) -> None:
                        """Write all buffered texts to the backing file and empty the buffer.

                        The first flush creates (truncates) the file; later flushes append.
                        Each text is written followed by a single "\\n".
                        """
                        mode = "a" if self._created else "w"
                        # newline="" disables newline translation so that read() gets back
                        # exactly the characters written here.
                        with open(self._path, mode, encoding="utf-8", newline="") as subjfile:
                            subjfile.writelines(f"{text}\n" for text in self._buffer)

                        self._buffer = []
                        self._created = True

                    def write(self, text: str) -> None:
                        """Add one text to the buffer, flushing when the buffer is full."""
                        self._buffer.append(text)
                        if len(self._buffer) >= self.BUFFER_SIZE:
                            self.flush()

                    def read(self) -> str:
                        """Return all texts written so far, in order, joined by "\\n"."""
                        if not self._created:
                            # file was never created - the buffer holds everything
                            return "\n".join(self._buffer)

                        with open(self._path, "r", encoding="utf-8", newline="") as subjfile:
                            flushed = subjfile.read()

                        # Every flushed text was terminated with "\n" by flush(), so the
                        # file content already ends with the separator that precedes the
                        # first in-memory text; adding another would insert a blank line.
                        if not self._buffer:
                            # drop the terminator of the last flushed text ("" stays "")
                            return flushed[:-1]
                        return flushed + "\n".join(self._buffer)
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  | class TFIDFBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |     """TF-IDF vector space similarity based backend for Annif""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |     name = "tfidf" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |     # defaults for uninitialized instances | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |     _tfidf_matrix = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |     MATRIX_FILE = "tfidf-matrix.npz" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |     OLD_INDEX_FILE = "tfidf-index" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _generate_subjects_from_documents(
                    self, corpus: DocumentCorpus
                ) -> Iterator[str]:
                    """Yield one concatenated text per subject, ordered by subject ID.

                    The tokenized text of each document (tokens joined by single spaces)
                    is added to the buffer of every subject in the document's subject set.
                    Then one text is yielded per subject ID in
                    ``range(len(self.project.subjects))``. The buffers live in a temporary
                    directory that is removed when the generator finishes or is closed.
                    """
                    n_subjects = len(self.project.subjects)
                    with tempfile.TemporaryDirectory() as tempdir:
                        subject_buffer = {
                            subject_id: SubjectBuffer(tempdir, subject_id)
                            for subject_id in range(n_subjects)
                        }

                        for doc in corpus.documents:
                            if not doc.subject_set:
                                # no subject would receive this text; skip tokenizing it
                                continue
                            tokens = self.project.analyzer.tokenize_words(doc.text)
                            # join once per document instead of once per assigned subject
                            text = " ".join(tokens)
                            for subject_id in doc.subject_set:
                                subject_buffer[subject_id].write(text)

                        for subject_id in range(n_subjects):
                            yield subject_buffer[subject_id].read()
            
                                                                                                            
                                                                
            
                                    
            
            
                | 93 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def _initialize_index(self) -> None:
                    """Load the TF-IDF matrix from the data directory unless already loaded.

                    :raises OperationFailedException: if no matrix file exists but an
                        index file in the old (pre-1.4) format is present.
                    :raises NotInitializedException: if neither file exists.
                    """
                    if self._tfidf_matrix is not None:
                        return  # already loaded - nothing to do

                    matrix_path = os.path.join(self.datadir, self.MATRIX_FILE)
                    self.debug("loading tf-idf matrix from {}".format(matrix_path))

                    if not os.path.exists(matrix_path):
                        legacy_index = os.path.join(self.datadir, self.OLD_INDEX_FILE)
                        if os.path.exists(legacy_index):
                            raise OperationFailedException(
                                "TFIDF models trained on Annif versions older than 1.4 cannot be "
                                "loaded. Please retrain your project."
                            )
                        raise NotInitializedException(
                            "tf-idf matrix {} not found".format(matrix_path),
                            backend_id=self.backend_id,
                        )

                    self._tfidf_matrix = load_npz(matrix_path)
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def initialize(self, parallel: bool = False) -> None:
                    """Prepare the backend: initialize the vectorizer, then load the matrix.

                    ``parallel`` is not used by this implementation.
                    """
                    # order matters: vectorizer setup runs before the index is loaded
                    for setup_step in (self.initialize_vectorizer, self._initialize_index):
                        setup_step()
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _train(
                    self,
                    corpus: DocumentCorpus,
                    params: dict[str, Any],
                    jobs: int = 0,
                ) -> None:
                    """Fit the vectorizer on subject texts and persist the tf-idf matrix.

                    The training corpus is turned into subject texts by
                    ``_generate_subjects_from_documents``. The vectorizer is created from
                    those texts, and the returned matrix is row-normalized, kept in
                    ``self._tfidf_matrix`` and written atomically as ``MATRIX_FILE``
                    under ``self.datadir`` using ``save_npz``.

                    ``params`` and ``jobs`` belong to the backend interface and are not
                    used by this implementation.

                    Raises:
                        NotSupportedException: when training from cached data is
                            requested, or when the corpus has no documents.
                    """
                    # Guard clauses: reject unsupported training inputs up front.
                    if corpus == "cached":
                        raise NotSupportedException(
                            "Training tfidf project from cached data not supported."
                        )
                    if corpus.is_empty():
                        raise NotSupportedException("Cannot train tfidf project with no documents")

                    self.info("transforming subject corpus")
                    subject_texts = self._generate_subjects_from_documents(corpus)

                    # The vectorizer is deliberately given no tokenizer: tokenization
                    # already happens in _generate_subjects_from_documents (and in
                    # _suggest_batch at query time). This way, a training document that
                    # has many subjects is not tokenized over and over.
                    subject_vectors = self.create_vectorizer(subject_texts)
                    self._tfidf_matrix = normalize(subject_vectors)

                    def _write_matrix(obj, filename):
                        # atomic_save hands over (object, target filename); save_npz
                        # expects them in the opposite order.
                        save_npz(filename, obj)

                    self.info("saving tf-idf matrix")
                    annif.util.atomic_save(
                        self._tfidf_matrix,
                        self.datadir,
                        self.MATRIX_FILE,
                        _write_matrix,
                    )
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _suggest_batch(
                    self, documents: list[Document], params: dict[str, Any]
                ) -> SuggestionBatch:
                    """Score documents against the stored tf-idf matrix.

                    Steps:
                    1. Tokenize each document's text with the project analyzer and
                       re-join the tokens with single spaces. Tokenization happens here
                       rather than inside the vectorizer.
                    2. Transform the texts with the fitted vectorizer and row-normalize
                       the result.
                    3. Multiply by the transposed tf-idf matrix. That matrix was
                       row-normalized at training time, so the product holds cosine
                       similarities.

                    The similarity matrix is wrapped in a SuggestionBatch and
                    ``SuggestionBatch.filter`` is called with ``int(params["limit"])``.
                    This is presumably a per-document top-N cut-off; confirm against
                    SuggestionBatch.
                    """
                    token_strings = [
                        " ".join(self.project.analyzer.tokenize_words(doc.text))
                        for doc in documents
                    ]
                    doc_vectors = normalize(self.vectorizer.transform(token_strings))

                    # Cosine similarity between every query row and every indexed row.
                    scores = doc_vectors @ self._tfidf_matrix.T

                    batch = SuggestionBatch(csr_array(scores))
                    return batch.filter(int(params["limit"]))
            
                                                        
            
                                    
            
            
                | 158 |  |  |  |