| 1 |  |  | """Annif backend using the Omikuji classifier""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from __future__ import annotations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import os.path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import shutil | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from typing import TYPE_CHECKING, Any | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | import omikuji | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | import annif.util | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from annif.exception import ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |     NotInitializedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |     NotSupportedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     OperationFailedException, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | from annif.suggestion import SubjectSuggestion, SuggestionBatch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | from . import backend, mixins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  | if TYPE_CHECKING: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |     from scipy.sparse._csr import csr_matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |     from annif.corpus.document import DocumentCorpus | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  | class OmikujiBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |     """Omikuji based backend for Annif""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |     name = "omikuji" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |     # defaults for uninitialized instances | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |     _model = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |     TRAIN_FILE = "omikuji-train.txt" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |     MODEL_FILE = "omikuji-model" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |     DEFAULT_PARAMETERS = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |         "min_df": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |         "ngram": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |         "cluster_balanced": True, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |         "cluster_k": 2, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         "max_depth": 20, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |         "collapse_every_n_layers": 0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     } | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 45 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 46 |  |  |     def _initialize_model(self) -> None: | 
            
                                                                        
                            
            
                                    
            
            
                | 47 |  |  |         if self._model is None: | 
            
                                                                        
                            
            
                                    
            
            
                | 48 |  |  |             path = os.path.join(self.datadir, self.MODEL_FILE) | 
            
                                                                        
                            
            
                                    
            
            
                | 49 |  |  |             self.debug("loading model from {}".format(path)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |             if os.path.exists(path): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |                 try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |                     self._model = omikuji.Model.load(path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |                 except RuntimeError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |                     raise OperationFailedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                         "Omikuji models trained on Annif versions older than " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |                         "0.56 cannot be loaded. Please retrain your project." | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |                     ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |                 raise NotInitializedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |                     "model {} not found".format(path), backend_id=self.backend_id | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |                 ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |     def initialize(self, parallel: bool = False) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         self.initialize_vectorizer() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         self._initialize_model() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |     def _create_train_file(self, veccorpus: csr_matrix, corpus: DocumentCorpus) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         self.info("creating train file") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         path = os.path.join(self.datadir, self.TRAIN_FILE) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         with open(path, "w", encoding="utf-8") as trainfile: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |             # Extreme Classification Repository format header line | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |             # We don't yet know the number of samples, as some may be skipped | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |             print( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |                 "00000000", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |                 len(self.vectorizer.vocabulary_), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |                 len(self.project.subjects), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |                 file=trainfile, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |             n_samples = 0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |             for doc, vector in zip(corpus.documents, veccorpus): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |                 subject_ids = [str(subject_id) for subject_id in doc.subject_set] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |                 feature_values = [ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |                     "{}:{}".format(col, vector[row, col]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |                     for row, col in zip(*vector.nonzero()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |                 ] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |                 if not subject_ids or not feature_values: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |                     continue  # noqa | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |                 print(",".join(subject_ids), " ".join(feature_values), file=trainfile) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |                 n_samples += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |             # replace the number of samples value at the beginning | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |             trainfile.seek(0) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |             print("{:08d}".format(n_samples), end="", file=trainfile) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |     def _create_model(self, params: dict[str, Any], jobs: int) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         train_path = os.path.join(self.datadir, self.TRAIN_FILE) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         model_path = os.path.join(self.datadir, self.MODEL_FILE) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         hyper_param = omikuji.Model.default_hyper_param() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |         hyper_param.cluster_balanced = annif.util.boolean(params["cluster_balanced"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         hyper_param.cluster_k = int(params["cluster_k"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         hyper_param.max_depth = int(params["max_depth"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         hyper_param.collapse_every_n_layers = int(params["collapse_every_n_layers"]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         self._model = omikuji.Model.train_on_data(train_path, hyper_param, jobs or None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         if os.path.exists(model_path): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |             shutil.rmtree(model_path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         self._model.save(os.path.join(self.datadir, self.MODEL_FILE)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _train(
                    self,
                    corpus: DocumentCorpus | str,
                    params: dict[str, Any],
                    jobs: int = 0,
                ) -> None:
                    """Train the Omikuji model, optionally reusing cached training data.

                    When ``corpus`` is the literal string ``"cached"``, no documents are
                    vectorized; the training data written by a previous run is reused.
                    Otherwise the document texts are vectorized and a fresh training
                    file is written. In both cases ``_create_model`` is then called to
                    train the model.

                    :param corpus: documents to train on, or the string ``"cached"`` to
                        reuse previously written training data
                    :param params: backend parameters; ``min_df`` and ``ngram`` are read
                        here for the vectorizer, and the whole dict is also passed on
                        to ``_create_model``
                    :param jobs: number of parallel training jobs, passed to
                        ``_create_model``
                    :raises NotSupportedException: if ``corpus`` contains no documents
                    """
                    if corpus != "cached":
                        if corpus.is_empty():
                            raise NotSupportedException(
                                "Cannot train omikuji project with no documents"
                            )
                        # Generator: document texts are produced lazily for the vectorizer.
                        doc_texts = (doc.text for doc in corpus.documents)
                        vecparams = {
                            "min_df": int(params["min_df"]),
                            "tokenizer": self.project.analyzer.tokenize_words,
                            "ngram_range": (1, int(params["ngram"])),
                        }
                        veccorpus = self.create_vectorizer(doc_texts, vecparams)
                        self._create_train_file(veccorpus, corpus)
                    else:
                        self.info("Reusing cached training data from previous run.")
                    self._create_model(params, jobs)
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _suggest_batch(
                    self, texts: list[str], params: dict[str, Any]
                ) -> SuggestionBatch:
                    """Suggest subjects for a batch of texts using the Omikuji model.

                    Each text is transformed with ``self.vectorizer`` and its nonzero
                    features are passed to ``self._model.predict``, keeping at most
                    ``params["limit"]`` subjects. A text whose vector stores no entries
                    gets an empty result without querying the model.

                    :param texts: input texts to classify
                    :param params: backend parameters; ``limit`` is read here
                    :returns: a ``SuggestionBatch`` built from the per-text result
                        lists, collected in input order
                    """
                    vectors = self.vectorizer.transform(texts)
                    limit = int(params["limit"])

                    batch_results = []
                    for row in vectors:
                        if row.nnz == 0:  # All zero vector, empty result
                            batch_results.append([])
                            continue
                        # Take (column, value) pairs straight from the COO form of the
                        # row instead of indexing the sparse matrix once per feature.
                        # The != 0 mask matches what ``row.nonzero()`` selects.
                        # NOTE(review): assumes the vectorizer emits no duplicate
                        # entries (canonical format) - confirm if the vectorizer changes.
                        coo = row.tocoo()
                        feature_values = [
                            (col, val) for col, val in zip(coo.col, coo.data) if val != 0
                        ]
                        batch_results.append(
                            [
                                SubjectSuggestion(subject_id=subj_id, score=score)
                                for subj_id, score in self._model.predict(
                                    feature_values, top_k=limit
                                )
                            ]
                        )
                    return SuggestionBatch.from_sequence(batch_results, self.project.subjects)
            
                                                        
            
                                    
            
            
                | 149 |  |  |  |