| 1 |  |  | """Annif backend using XTransformer, the transformer variant of PECOS.""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import logging | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import os.path as osp | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from sys import stdout | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from typing import Any | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | import scipy.sparse as sp | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from pecos.utils.featurization.text.preprocess import Preprocessor | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from pecos.xmc.xtransformer import matcher, model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from pecos.xmc.xtransformer.model import XTransformer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | from pecos.xmc.xtransformer.module import MLProblemWithText | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from annif.corpus.document import DocumentCorpus | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  | from annif.exception import NotInitializedException, NotSupportedException | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  | from annif.suggestion import SubjectSuggestion, SuggestionBatch | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | from annif.util import ( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |     apply_param_parse_config, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |     atomic_save, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |     atomic_save_folder, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |     boolean, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  | ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  | from . import backend, mixins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  | class XTransformerBackend(mixins.TfidfVectorizerMixin, backend.AnnifBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |     """XTransformer-based backend for Annif.""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |     name = "xtransformer" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |     needs_subject_index = True | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |     _model = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |     train_X_file = "xtransformer-train-X.npz" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |     train_y_file = "xtransformer-train-y.npz" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |     train_txt_file = "xtransformer-train-raw.txt" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |     model_folder = "xtransformer-model" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |     PARAM_CONFIG = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |         "min_df": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |         "ngram": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |         "fix_clustering": boolean, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         "nr_splits": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         "min_codes": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         "max_leaf_size": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         "imbalanced_ratio": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |         "imbalanced_depth": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |         "max_match_clusters": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         "do_fine_tune": boolean, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |         "model_shortcut": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |         "beam_size": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |         "limit": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |         "post_processor": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |         "negative_sampling": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         "ensemble_method": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |         "threshold": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |         "loss_function": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |         "truncate_length": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |         "hidden_droput_prob": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         "batch_size": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         "gradient_accumulation_steps": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         "learning_rate": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         "weight_decay": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         "adam_epsilon": float, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         "num_train_epochs": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         "max_steps": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         "lr_schedule": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         "warmup_steps": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         "logging_steps": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |         "save_steps": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         "max_active_matching_labels": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         "max_num_labels_in_gpu": int, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         "use_gpu": boolean, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         "bootstrap_model": str, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |     DEFAULT_PARAMETERS = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |         "min_df": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |         "ngram": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         "fix_clustering": False, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |         "nr_splits": 16, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |         "min_codes": None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         "max_leaf_size": 100, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         "imbalanced_ratio": 0.0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         "imbalanced_depth": 100, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |         "max_match_clusters": 32768, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         "do_fine_tune": True, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         "model_shortcut": "distilbert-base-multilingual-uncased", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         "beam_size": 20, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         "limit": 100, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         "post_processor": "sigmoid", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         "negative_sampling": "tfn", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         "ensemble_method": "transformer-only", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         "threshold": 0.1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         "loss_function": "squared-hinge", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         "truncate_length": 128, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |         "hidden_droput_prob": 0.1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         "batch_size": 32, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         "gradient_accumulation_steps": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         "learning_rate": 1e-4, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         "weight_decay": 0.0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         "adam_epsilon": 1e-8, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         "num_train_epochs": 1, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |         "max_steps": 0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         "lr_schedule": "linear", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         "warmup_steps": 0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |         "logging_steps": 100, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         "save_steps": 1000, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |         "max_active_matching_labels": None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         "max_num_labels_in_gpu": 65536, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         "use_gpu": True, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         "bootstrap_model": "linear", | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |     } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _initialize_model(self):
                    """Load the trained XTransformer model unless one is already held.

                    The model is read from the ``model_folder`` directory inside
                    ``datadir`` and stored in ``self._model``.

                    Raises:
                        NotInitializedException: if the model path does not exist.
                    """
                    if self._model is not None:
                        # A model is already in memory; nothing to load.
                        return

                    model_path = osp.join(self.datadir, self.model_folder)
                    self.debug("loading model from {}".format(model_path))

                    if not osp.exists(model_path):
                        raise NotInitializedException(
                            "model {} not found".format(model_path),
                            backend_id=self.backend_id,
                        )
                    self._model = XTransformer.load(model_path)
            
                                                                                                            
                                                                
            
                                    
            
            
                | 127 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                def initialize(self, parallel: bool = False) -> None:
                    """Initialize the vectorizer and then the model of this backend.

                    Args:
                        parallel: accepted as part of the signature but not read by
                            this method.
                    """
                    # Delegates to initialize_vectorizer() first, then to
                    # _initialize_model(), in that order.
                    self.initialize_vectorizer()
                    self._initialize_model()
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def default_params(self):
                    """Return the default parameters for this backend as a new dict.

                    The generic ``AnnifBackend`` defaults form the base; entries from
                    this class's own ``DEFAULT_PARAMETERS`` win on key collisions.
                    Neither source mapping is modified.
                    """
                    return {
                        **backend.AnnifBackend.DEFAULT_PARAMETERS,
                        **self.DEFAULT_PARAMETERS,
                    }
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _create_train_files(self, veccorpus, corpus):
                    """Write the training text, feature and label files to the data dir.

                    Documents that lack either subjects or text are skipped. For each
                    remaining document, one whitespace-normalised line is written to
                    the training text file and one row is appended to both the feature
                    matrix and the label indicator matrix, so the three outputs stay
                    row-aligned. Both matrices are saved as compressed ``.npz`` files
                    through ``atomic_save``.

                    Args:
                        veccorpus: per-document feature vectors, iterated in lockstep
                            with ``corpus.documents``.
                        corpus: corpus whose ``documents`` expose ``text`` and
                            ``subject_set``.

                    Raises:
                        NotSupportedException: if no document has both text and
                            subjects, i.e. there is nothing to train on.
                    """
                    self.info("creating train file")
                    # Loop-invariant: width of every label row.
                    num_subjects = len(self.project.subjects)
                    feature_rows = []
                    label_rows = []
                    txt_pth = osp.join(self.datadir, self.train_txt_file)
                    with open(txt_pth, "w", encoding="utf-8") as txt_file:
                        for doc, vector in zip(corpus.documents, veccorpus):
                            subject_set = doc.subject_set
                            if not (subject_set and doc.text):
                                continue
                            # Collapse all whitespace runs (including newlines) so that
                            # each document occupies exactly one line of the text file.
                            print(" ".join(doc.text.split()), file=txt_file)
                            feature_rows.append(
                                sp.csr_matrix(vector, dtype=np.float32).sorted_indices()
                            )
                            # Single-row indicator matrix: a 1 in the column of every
                            # subject assigned to this document.
                            label_rows.append(
                                sp.csr_matrix(
                                    (
                                        np.ones(len(subject_set)),
                                        (np.zeros(len(subject_set)), list(subject_set)),
                                    ),
                                    shape=(1, num_subjects),
                                    dtype=np.float32,
                                ).sorted_indices()
                            )

                    if not feature_rows:
                        # sp.vstack() cannot stack an empty list; fail with a clear,
                        # domain-level error instead of an opaque scipy ValueError.
                        raise NotSupportedException(
                            "Cannot train project: no documents with both text and subjects"
                        )

                    def save_compressed(mtrx, target):
                        # atomic_save calls method(object, target_path).
                        sp.save_npz(target, mtrx, compressed=True)

                    atomic_save(
                        sp.vstack(feature_rows, format="csr"),
                        self.datadir,
                        self.train_X_file,
                        method=save_compressed,
                    )
                    atomic_save(
                        sp.vstack(label_rows, format="csr"),
                        self.datadir,
                        self.train_y_file,
                        method=save_compressed,
                    )
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _create_model(self, params, jobs):
                    """Train an XTransformer model from the saved training files and store it.

                    Reads the training text file and the feature/label ``.npz``
                    matrices from the data directory, builds PECOS train and predict
                    parameter dicts, trains the model into ``self._model`` and saves it
                    to the model folder with ``atomic_save_folder``.

                    Args:
                        params: training parameters; only ``beam_size`` is read here.
                        jobs: not used by this method.
                    """
                    train_txts = Preprocessor.load_data_from_file(
                        osp.join(self.datadir, self.train_txt_file),
                        label_text_path=None,
                        text_pos=0,
                    )["corpus"]
                    train_X = sp.load_npz(osp.join(self.datadir, self.train_X_file))
                    train_y = sp.load_npz(osp.join(self.datadir, self.train_y_file))
                    model_path = osp.join(self.datadir, self.model_folder)

                    # NOTE(review): the PECOS parameters are derived from self.params,
                    # while beam_size below is taken from the `params` argument --
                    # confirm whether overrides passed in `params` should apply here.
                    new_params = apply_param_parse_config(self.PARAM_CONFIG, self.params)
                    # The "limit" entry is handed to PECOS under the name "only_topk".
                    new_params["only_topk"] = new_params.pop("limit")
                    train_params = XTransformer.TrainParams.from_dict(
                        new_params, recursive=True
                    ).to_dict()
                    pred_params = XTransformer.PredParams.from_dict(
                        new_params, recursive=True
                    ).to_dict()

                    self.info("Start training")
                    # enable progress
                    for pecos_logger in (matcher.LOGGER, model.LOGGER):
                        pecos_logger.setLevel(logging.DEBUG)
                        # Attach the stdout handler only once per logger; otherwise
                        # every further training run in the same process would add
                        # another handler and duplicate each progress line.
                        has_stdout_handler = any(
                            isinstance(handler, logging.StreamHandler)
                            and handler.stream is stdout
                            for handler in pecos_logger.handlers
                        )
                        if not has_stdout_handler:
                            pecos_logger.addHandler(logging.StreamHandler(stream=stdout))

                    self._model = XTransformer.train(
                        MLProblemWithText(train_txts, train_y, X_feat=train_X),
                        clustering=None,
                        val_prob=None,
                        train_params=train_params,
                        pred_params=pred_params,
                        beam_size=int(params["beam_size"]),
                        steps_scale=None,
                        label_feat=None,
                    )
                    atomic_save_folder(self._model, model_path)
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |  | 
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 208 |  | View Code Duplication |     def _train( | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |         self, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |         corpus: DocumentCorpus, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |         params: dict[str, Any], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |         jobs: int = 0, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |     ) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |         if corpus == "cached": | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |             self.info("Reusing cached training data from previous run.") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |             if corpus.is_empty(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |                 raise NotSupportedException("Cannot t project with no documents") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |             input = (doc.text for doc in corpus.documents) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |             vecparams = { | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |                 "min_df": int(params["min_df"]), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |                 "tokenizer": self.project.analyzer.tokenize_words, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |                 "ngram_range": (1, int(params["ngram"])), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |             } | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |             veccorpus = self.create_vectorizer(input, vecparams) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |             self._create_train_files(veccorpus, corpus) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |         self._create_model(params, jobs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _suggest_batch(
                    self, texts: list[str], params: dict[str, Any]
                ) -> SuggestionBatch:
                    """Suggest subjects for a batch of texts.

                    The texts are vectorized with ``self.vectorizer`` and handed,
                    together with the raw texts, to ``self._model.predict``. Every row
                    of the returned prediction is turned into a list of
                    ``SubjectSuggestion`` objects.

                    Args:
                        texts: The documents to produce suggestions for.
                        params: Backend parameters. They are parsed with
                            ``apply_param_parse_config(self.PARAM_CONFIG, params)``;
                            the ``batch_size``, ``limit`` and ``post_processor``
                            entries of the parsed result are used.

                    Returns:
                        A ``SuggestionBatch`` built from one suggestion list per
                        prediction row. When none of the texts yields any feature,
                        the batch holds one empty suggestion list per input text.
                    """
                    vector = self.vectorizer.transform(texts)

                    if vector.nnz == 0:
                        # No text produced a single feature, so prediction is skipped.
                        # Still return a real SuggestionBatch with one (empty) result
                        # per text instead of a bare list, so callers always get the
                        # annotated type and a result aligned with their input.
                        return SuggestionBatch.from_sequence(
                            [[] for _ in texts], self.project.subjects
                        )

                    new_params = apply_param_parse_config(self.PARAM_CONFIG, params)
                    prediction = self._model.predict(
                        texts,
                        # The feature matrix is passed with its indices sorted.
                        X_feat=vector.sorted_indices(),
                        batch_size=new_params["batch_size"],
                        use_gpu=True,
                        only_top_k=new_params["limit"],
                        post_processor=new_params["post_processor"],
                    )

                    # One prediction row per text; each row's stored (index, value)
                    # pairs become (subject_id, score) suggestions.
                    rows = (
                        prediction.getrow(i) for i in range(prediction.get_shape()[0])
                    )
                    batch_result = [
                        [
                            SubjectSuggestion(subject_id=idx, score=score)
                            for idx, score in zip(row.indices, row.data)
                        ]
                        for row in rows
                    ]
                    return SuggestionBatch.from_sequence(
                        batch_result, self.project.subjects
                    )
            
                                                        
            
                                    
            
            
                | 254 |  |  |  |