| 1 |  |  | """PAV ensemble backend that combines results from multiple projects and | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | learns which concept suggestions from each backend are trustworthy using the | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | PAV algorithm, a.k.a. isotonic regression, to turn raw scores returned by | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | individual backends into probabilities.""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import os.path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | import joblib | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | from scipy.sparse import coo_matrix, csc_matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from sklearn.isotonic import IsotonicRegression | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | import annif.corpus | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | import annif.suggestion | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  | import annif.util | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  | from annif.exception import NotInitializedException, NotSupportedException | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | from . import ensemble | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  | class PAVBackend(ensemble.BaseEnsembleBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |     """PAV ensemble backend that combines results from multiple projects""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |     name = "pav" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |     MODEL_FILE_PREFIX = "pav-model-" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |     # defaults for uninitialized instances | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |     _models = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |     DEFAULT_PARAMETERS = {'min-docs': 10} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |     def initialize(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |         super().initialize() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |         if self._models is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |             return  # already initialized | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |         self._models = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |         sources = annif.util.parse_sources(self.params['sources']) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |         for source_project_id, _ in sources: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |             model_filename = self.MODEL_FILE_PREFIX + source_project_id | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |             path = os.path.join(self.datadir, model_filename) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |             if os.path.exists(path): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |                 self.debug('loading PAV model from {}'.format(path)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |                 self._models[source_project_id] = joblib.load(path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |                 raise NotInitializedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |                     "PAV model file '{}' not found".format(path), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |                     backend_id=self.backend_id) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 45 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 46 |  |  |     def _get_model(self, source_project_id): | 
            
                                                                        
                            
            
                                    
            
            
                | 47 |  |  |         self.initialize() | 
            
                                                                        
                            
            
                                    
            
            
                | 48 |  |  |         return self._models[source_project_id] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |     def _normalize_hits(self, hits, source_project): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |         reg_models = self._get_model(source_project.project_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |         pav_result = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |         for hit in hits.as_list(source_project.subjects): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |             if hit.uri in reg_models: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                 score = reg_models[hit.uri].predict([hit.score])[0] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |             else:  # default to raw score | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |                 score = hit.score | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |             pav_result.append( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |                 annif.suggestion.SubjectSuggestion( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |                     uri=hit.uri, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |                     label=hit.label, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |                     notation=hit.notation, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |                     score=score)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         pav_result.sort(key=lambda hit: hit.score, reverse=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         return annif.suggestion.ListSuggestionResult(pav_result) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |     def _suggest_train_corpus(source_project, corpus): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         # lists for constructing score matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         data, row, col = [], [], [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         # lists for constructing true label matrix | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |         trow, tcol = [], [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         ndocs = 0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         for docid, doc in enumerate(corpus.documents): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |             hits = source_project.suggest(doc.text) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |             vector = hits.as_vector(source_project.subjects) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |             for cid in np.flatnonzero(vector): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |                 data.append(vector[cid]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |                 row.append(docid) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |                 col.append(cid) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |             subjects = annif.corpus.SubjectSet((doc.uris, doc.labels)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |             for cid in np.flatnonzero( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |                     subjects.as_vector(source_project.subjects)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |                 trow.append(docid) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |                 tcol.append(cid) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |             ndocs += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         scores = coo_matrix((data, (row, col)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |                             shape=(ndocs, len(source_project.subjects)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |                             dtype=np.float32) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         true = coo_matrix((np.ones(len(trow), dtype=np.bool), (trow, tcol)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |                           shape=(ndocs, len(source_project.subjects)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |                           dtype=np.bool) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         return csc_matrix(scores), csc_matrix(true) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |     def _create_pav_model(self, source_project_id, min_docs, corpus): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         self.info("creating PAV model for source {}, min_docs={}".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |             source_project_id, min_docs)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         source_project = self.project.registry.get_project(source_project_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         # suggest subjects for the training corpus | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         scores, true = self._suggest_train_corpus(source_project, corpus) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         # create the concept-specific PAV regression models | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         pav_regressions = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         for cid in range(len(source_project.subjects)): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |             if true[:, cid].sum() < min_docs: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |                 continue  # don't create model b/c of too few examples | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |             reg = IsotonicRegression(out_of_bounds='clip') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |             cid_scores = scores[:, cid].toarray().flatten().astype(np.float64) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |             reg.fit(cid_scores, true[:, cid].toarray().flatten()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |             pav_regressions[source_project.subjects[cid][0]] = reg | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         self.info("created PAV model for {} concepts".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |             len(pav_regressions))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |         model_filename = self.MODEL_FILE_PREFIX + source_project_id | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |         annif.util.atomic_save( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |             pav_regressions, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |             self.datadir, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |             model_filename, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |             method=joblib.dump) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _train(self, corpus, params):
                    """Train the ensemble by building PAV models for each source.

                    The source projects are read from this backend's own
                    ``sources`` setting, while the ``min-docs`` threshold is
                    taken from the ``params`` given to this call. One PAV model
                    is created per source project via ``_create_pav_model``.

                    Raises NotSupportedException if ``corpus`` is the string
                    ``'cached'`` or if the corpus reports itself as empty.
                    """
                    # Guard clauses: reject inputs this backend cannot train on.
                    if corpus == 'cached':
                        raise NotSupportedException(
                            'Training pav project from cached data not supported.')
                    if corpus.is_empty():
                        empty_msg = 'training backend {} with no documents'.format(
                            self.backend_id)
                        raise NotSupportedException(empty_msg)

                    self.info("creating PAV models")

                    source_spec = self.params['sources']
                    parsed_sources = annif.util.parse_sources(source_spec)
                    doc_threshold = int(params['min-docs'])

                    # Only the project id of each parsed source entry is needed
                    # here; the second element of the pair is ignored.
                    for project_id, _unused in parsed_sources:
                        self._create_pav_model(project_id, doc_threshold, corpus)
            
                                                        
            
                                    
            
            
                | 133 |  |  |  |