| 1 |  |  | """Annif backend using the fastText classifier""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import collections | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import os.path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import annif.util | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from annif.suggestion import SubjectSuggestion, ListSuggestionResult | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from annif.exception import NotInitializedException, NotSupportedException | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | import fasttext | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from . import backend | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from . import mixins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | class FastTextBackend(mixins.ChunkingBackend, backend.AnnifBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     """fastText backend for Annif""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                # Name of this backend type. Presumably the key used to select the
                # backend in project configuration -- TODO confirm against the
                # backend registry, which is outside this chunk.
                name = "fasttext"
                # NOTE(review): the consumer of this flag is not visible here; it
                # looks like it tells the framework that this backend requires the
                # project's subject index -- confirm.
                needs_subject_index = True

                # fastText hyperparameter names mapped to the Python type of their
                # values (float, int or str).
                # NOTE(review): the code that reads this table is outside this chunk;
                # presumably it converts configured values to these types before
                # handing them to fastText -- confirm.
                FASTTEXT_PARAMS = {
                    'lr': float,
                    'lrUpdateRate': int,
                    'dim': int,
                    'ws': int,
                    'epoch': int,
                    'minCount': int,
                    'neg': int,
                    'wordNgrams': int,
                    'loss': str,
                    'bucket': int,
                    'minn': int,
                    'maxn': int,
                    'thread': int,
                    't': float,
                    'pretrainedVectors': str
                }

                # Backend-specific defaults. default_params() applies these last, so
                # they override the AnnifBackend and ChunkingBackend defaults for any
                # key they share.
                DEFAULT_PARAMETERS = {
                    'dim': 100,
                    'lr': 0.25,
                    'epoch': 5,
                    'loss': 'hs',
                }

                # File name of the trained model; initialize() looks for it inside
                # self.datadir.
                MODEL_FILE = 'fasttext-model'
                # File name of the training data that _create_train_file() saves
                # into self.datadir.
                TRAIN_FILE = 'fasttext-train.txt'

                # defaults for uninitialized instances
                # (_model stays None until initialize() has loaded the model file)
                _model = None
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def default_params(self):
                    """Return the effective default parameters of this backend.

                    The defaults are layered from generic to specific: a copy of the
                    AnnifBackend defaults is updated with the ChunkingBackend defaults
                    and then with this class's DEFAULT_PARAMETERS, so a later layer
                    overrides any key it shares with an earlier one. The shared
                    class-level dicts are never modified.
                    """
                    merged = backend.AnnifBackend.DEFAULT_PARAMETERS.copy()
                    overrides = (
                        mixins.ChunkingBackend.DEFAULT_PARAMETERS,
                        self.DEFAULT_PARAMETERS,
                    )
                    for layer in overrides:
                        merged.update(layer)
                    return merged
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def _load_model(path):
                    """Load a fastText model from disk without the spurious warning.

                    fasttext.FastText.eprint is monkey patched with a no-op while the
                    model is loaded, to avoid the warning described in
                    https://github.com/facebookresearch/fastText/issues/1067

                    :param path: path of the model file, passed to fasttext.load_model()
                    :returns: the model object returned by fasttext.load_model()

                    Any exception raised by fasttext.load_model() propagates to the
                    caller; the original eprint is restored in that case as well.
                    """
                    orig_eprint = fasttext.FastText.eprint
                    fasttext.FastText.eprint = lambda x: None
                    try:
                        return fasttext.load_model(path)
                    finally:
                        # restore the original eprint even when loading fails, so
                        # that a failed load cannot leave fastText's own messages
                        # silenced for the rest of the process
                        fasttext.FastText.eprint = orig_eprint
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |     def initialize(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         if self._model is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |             path = os.path.join(self.datadir, self.MODEL_FILE) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |             self.debug('loading fastText model from {}'.format(path)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |             if os.path.exists(path): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |                 self._model = self._load_model(path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |                 self.debug('loaded model {}'.format(str(self._model))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |                 self.debug('dim: {}'.format(self._model.get_dimension())) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |                 raise NotInitializedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |                     'model {} not found'.format(path), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |                     backend_id=self.backend_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |     def _id_to_label(subject_id): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         return "__label__{:d}".format(subject_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |     def _label_to_subject(self, label): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         labelnum = label.replace('__label__', '') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         subject_id = int(labelnum) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         return self.project.subjects[subject_id] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |     def _write_train_file(self, doc_subjects, filename): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         with open(filename, 'w', encoding='utf-8') as trainfile: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |             for doc, subject_ids in doc_subjects.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |                 labels = [self._id_to_label(sid) for sid in subject_ids | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |                           if sid is not None] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |                 if labels: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |                     print(' '.join(labels), doc, file=trainfile) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |                 else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |                     self.warning('no labels for document "{}"'.format(doc)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |     def _normalize_text(self, text): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         return ' '.join(self.project.analyzer.tokenize_words(text)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _create_train_file(self, corpus):
                    """Create the fastText training file from a document corpus.

                    The text of every document is normalized with _normalize_text();
                    documents whose normalized text is empty are skipped. The URIs of
                    each remaining document are resolved with
                    self.project.subjects.by_uri(). The resulting mapping is passed to
                    annif.util.atomic_save() together with self.datadir, TRAIN_FILE and
                    _write_train_file as the writing method.

                    :param corpus: object whose ``documents`` iterable yields items
                        having ``text`` and ``uris`` attributes
                    """
                    self.info('creating fastText training file')

                    # NOTE: entries are assigned rather than accumulated, so the set
                    # factory is never invoked, and a later document with the same
                    # normalized text replaces the subjects of an earlier one.
                    subjects_by_text = collections.defaultdict(set)

                    for document in corpus.documents:
                        normalized = self._normalize_text(document.text)
                        if normalized != '':
                            subjects_by_text[normalized] = [
                                self.project.subjects.by_uri(uri)
                                for uri in document.uris
                            ]

                    annif.util.atomic_save(
                        subjects_by_text,
                        self.datadir,
                        self.TRAIN_FILE,
                        method=self._write_train_file)
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _create_model(self, params):
                    """Train a fastText model from the training file and save it.

                    Only the entries of ``params`` whose key appears in
                    ``self.FASTTEXT_PARAMS`` are passed on to fastText; each
                    value is first run through the callable registered for that
                    key. The trained model is kept in ``self._model`` and
                    written to ``self.MODEL_FILE`` inside ``self.datadir``.
                    """
                    self.info('creating fastText model')
                    train_file = os.path.join(self.datadir, self.TRAIN_FILE)
                    model_file = os.path.join(self.datadir, self.MODEL_FILE)

                    # Keep only recognised parameters, converting each value
                    # with the callable registered for it.
                    model_params = {}
                    for name, raw_value in params.items():
                        if name not in self.FASTTEXT_PARAMS:
                            continue
                        model_params[name] = self.FASTTEXT_PARAMS[name](raw_value)
                    self.debug('Model parameters: {}'.format(model_params))

                    self._model = fasttext.train_supervised(
                        train_file, **model_params)
                    self._model.save_model(model_file)
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |     def _train(self, corpus, params): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |         if corpus != 'cached': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |             if corpus.is_empty(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |                 raise NotSupportedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |                     'training backend {} with no documents' .format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |                         self.backend_id)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |             self._create_train_file(corpus) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |             self.info("Reusing cached training data from previous run.") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |         self._create_model(params) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |     def _predict_chunks(self, chunktexts, limit): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |         return self._model.predict(list( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |             filter(None, [self._normalize_text(chunktext) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |                           for chunktext in chunktexts])), limit) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 145 |  |  |  | 
            
                                                        
            
                                    
            
            
                def _suggest_chunks(self, chunktexts, params):
                    """Combine per-chunk predictions into one suggestion result.

                    Scores predicted for the same label are summed over all
                    chunks. The ``params['limit']`` highest-scoring labels are
                    kept, and each summed score is divided by the number of
                    entries in ``chunktexts``. Returns a ListSuggestionResult
                    built from the resulting SubjectSuggestion objects.
                    """
                    limit = int(params['limit'])
                    chunklabels, chunkscores = self._predict_chunks(
                        chunktexts, limit)

                    # Sum the scores each label received across the chunks.
                    score_by_label = collections.defaultdict(float)
                    for labels, scores in zip(chunklabels, chunkscores):
                        for label, score in zip(labels, scores):
                            score_by_label[label] += score

                    # Highest total first; equal totals are ordered by label,
                    # descending.
                    ranked = sorted(score_by_label.items(),
                                    key=lambda item: (item[1], item[0]),
                                    reverse=True)

                    # _label_to_subject is indexed as (uri, label, notation).
                    top_subjects = ((self._label_to_subject(label), total)
                                    for label, total in ranked[:limit])
                    suggestions = [
                        SubjectSuggestion(
                            uri=subject[0],
                            label=subject[1],
                            notation=subject[2],
                            score=total / len(chunktexts))
                        for subject, total in top_subjects
                    ]
                    return ListSuggestionResult(suggestions, self.project.subjects)
            
                                                        
            
                                    
            
            
                | 167 |  |  |  |