"""Annif backend using the Vowpal Wabbit multiclass and multilabel
classifiers"""
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import random | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | import os.path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | import annif.util | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | from vowpalwabbit import pyvw | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from annif.hit import VectorAnalysisResult | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | from annif.exception import ConfigurationException, NotInitializedException | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from . import backend | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | from . import mixins | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  | class VWMultiBackend(mixins.ChunkingBackend, backend.AnnifBackend): | 
            
                                                                                                            
                            
            
                                    
            
            
    """Vowpal Wabbit multiclass/multilabel backend for Annif"""
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
    # backend type name
    # NOTE(review): presumably the value used to select this backend in the
    # project configuration -- confirm against the backend registry
    name = "vw_multi"
    # NOTE(review): presumably signals that this backend requires the
    # project's subject index -- confirm against backend.AnnifBackend
    needs_subject_index = True

    VW_PARAMS = {
        # each param specifier is a pair (allowed_values, default_value)
        # where allowed_values is either a type or a list of allowed values
        # and default_value may be None, to let VW decide by itself
        'bit_precision': (int, None),
        'ngram': (int, None),
        'learning_rate': (float, None),
        'loss_function': (['squared', 'logistic', 'hinge'], 'logistic'),
        'l1': (float, None),
        'l2': (float, None),
        'passes': (int, None),
        'probabilities': (bool, None)
    }

    # algorithm used when the 'algorithm' backend parameter is not set
    DEFAULT_ALGORITHM = 'oaa'
    # values accepted for the 'algorithm' backend parameter; the algorithm
    # property raises ConfigurationException for anything else
    SUPPORTED_ALGORITHMS = ('oaa', 'ect', 'log_multi', 'multilabel_oaa')

    # file names that get joined with the backend data directory
    MODEL_FILE = 'vw-model'
    TRAIN_FILE = 'vw-train.txt'

    # defaults for uninitialized instances
    # (initialize() only loads a model while this is still None)
    _model = None
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     def initialize(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         if self._model is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |             path = os.path.join(self._get_datadir(), self.MODEL_FILE) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |             if not os.path.exists(path): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |                 raise NotInitializedException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |                     'model {} not found'.format(path), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |                     backend_id=self.backend_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |             self.debug('loading VW model from {}'.format(path)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |             params = self._create_params({'i': path, 'quiet': True}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |             if 'passes' in params: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |                 # don't confuse the model with passes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                 del params['passes'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |             self.debug("model parameters: {}".format(params)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |             self._model = pyvw.vw(**params) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |             self.debug('loaded model {}'.format(str(self._model))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |     def algorithm(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         algorithm = self.params.get('algorithm', self.DEFAULT_ALGORITHM) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |         if algorithm not in self.SUPPORTED_ALGORITHMS: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |             raise ConfigurationException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |                 "{} is not a valid algorithm (allowed: {})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |                     algorithm, ', '.join(self.SUPPORTED_ALGORITHMS)), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |                 backend_id=self.backend_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         return algorithm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |     def _normalize_text(cls, project, text): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |         ntext = ' '.join(project.analyzer.tokenize_words(text)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         # colon and pipe chars have special meaning in VW and must be avoided | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |         return ntext.replace(':', '').replace('|', '') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     def _write_train_file(cls, examples, filename): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |         with open(filename, 'w') as trainfile: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |             for ex in examples: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |                 print(ex, file=trainfile) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |     def _uris_to_subject_ids(cls, project, uris): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |         subject_ids = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         for uri in uris: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |             subject_id = project.subjects.by_uri(uri) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |             if subject_id is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |                 subject_ids.append(subject_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         return subject_ids | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 90 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 91 |  |  |     def _format_examples(self, project, text, uris): | 
            
                                                                        
                            
            
                                    
            
            
                | 92 |  |  |         subject_ids = self._uris_to_subject_ids(project, uris) | 
            
                                                                        
                            
            
                                    
            
            
                | 93 |  |  |         if self.algorithm == 'multilabel_oaa': | 
            
                                                                        
                            
            
                                    
            
            
                | 94 |  |  |             yield '{} | {}'.format(','.join(map(str, subject_ids)), text) | 
            
                                                                        
                            
            
                                    
            
            
                | 95 |  |  |         else: | 
            
                                                                        
                            
            
                                    
            
            
                | 96 |  |  |             for subject_id in subject_ids: | 
            
                                                                        
                            
            
                                    
            
            
                | 97 |  |  |                 yield '{} | {}'.format(subject_id + 1, text) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |     def _create_train_file(self, corpus, project): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |         self.info('creating VW train file') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         examples = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         for doc in corpus.documents: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |             text = self._normalize_text(project, doc.text) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |             examples.extend(self._format_examples(project, text, doc.uris)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         random.shuffle(examples) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |         annif.util.atomic_save(examples, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |                                self._get_datadir(), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |                                self.TRAIN_FILE, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |                                method=self._write_train_file) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |     def _convert_param(self, param, val): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |         pspec, _ = self.VW_PARAMS[param] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         if isinstance(pspec, list): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |             if val in pspec: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |                 return val | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |             raise ConfigurationException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |                 "{} is not a valid value for {} (allowed: {})".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |                     val, param, ', '.join(pspec)), backend_id=self.backend_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |             return pspec(val) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |         except ValueError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |             raise ConfigurationException( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |                 "The {} value {} cannot be converted to {}".format( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |                     param, val, pspec), backend_id=self.backend_id) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |     def _create_params(self, params): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |         params.update({param: defaultval | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |                        for param, (_, defaultval) in self.VW_PARAMS.items() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |                        if defaultval is not None}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         params.update({param: self._convert_param(param, val) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |                        for param, val in self.params.items() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |                        if param in self.VW_PARAMS}) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |         return params | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _create_model(self, project):
                    """Create the VW model and save it in the data directory.

                    The model parameters start from the path of the train file and
                    the number of subjects in ``project`` (keyed by the algorithm
                    name) and are completed by ``_create_params``. The resulting
                    ``pyvw.vw`` instance is stored in ``self._model`` and saved as
                    ``MODEL_FILE``.
                    """
                    message = 'creating VW model (algorithm: {})'.format(self.algorithm)
                    self.info(message)

                    initial = {
                        'data': os.path.join(self._get_datadir(), self.TRAIN_FILE),
                        self.algorithm: len(project.subjects),
                    }
                    vw_params = self._create_params(initial)

                    multiple_passes = vw_params.get('passes', 1) > 1
                    if multiple_passes:
                        # need a cache file when there are multiple passes
                        vw_params['cache'] = True
                        vw_params['kill_cache'] = True

                    self.debug("model parameters: {}".format(vw_params))
                    self._model = pyvw.vw(**vw_params)

                    target = os.path.join(self._get_datadir(), self.MODEL_FILE)
                    self._model.save(target)
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def train(self, corpus, project):
                    """Train the backend: write the train file, then create the model.

                    ``corpus`` and ``project`` are handed to ``_create_train_file``;
                    ``project`` alone is then handed to ``_create_model``.
                    """
                    # step 1: turn the corpus into the training data file
                    self._create_train_file(corpus, project)
                    # step 2: create the model (must follow step 1)
                    self._create_model(project)
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _analyze_chunks(self, chunktexts, project):
                    """Predict subjects for each text chunk and average the results.

                    Each chunk is formatted as a VW example and passed to
                    ``self._model.predict``. The prediction is turned into a vector:

                    * with the ``multilabel_oaa`` algorithm the prediction is used
                      as a set of indices that are set to 1.0 in a zero vector;
                    * a single ``int`` prediction is one-hot encoded at index
                      ``prediction - 1``;
                    * anything else is converted with ``np.array``.

                    Returns a VectorAnalysisResult built from the element-wise mean
                    of the per-chunk vectors and ``project.subjects``. When
                    ``chunktexts`` is empty, an all-zero vector is used instead of
                    taking the mean of nothing (which would produce NaN).
                    """
                    subject_count = len(project.subjects)
                    vectors = []
                    for chunktext in chunktexts:
                        example = ' | {}'.format(chunktext)
                        result = self._model.predict(example)
                        if self.algorithm == 'multilabel_oaa':
                            # result is a list of subject IDs - need to vectorize
                            vector = np.zeros(subject_count)
                            vector[result] = 1.0
                        elif isinstance(result, int):
                            # result is a single integer - need to one-hot-encode
                            # (presumably the predicted label is 1-based, hence -1)
                            vector = np.zeros(subject_count)
                            vector[result - 1] = 1.0
                        else:
                            vector = np.array(result)
                        vectors.append(vector)

                    if not vectors:
                        # no chunks to analyze: avoid the NaN scalar (and the
                        # RuntimeWarning) that the mean of an empty array gives
                        return VectorAnalysisResult(
                            np.zeros(subject_count), project.subjects)

                    return VectorAnalysisResult(
                        np.array(vectors).mean(axis=0), project.subjects)
            
                                                        
            
                                    
            
            
                | 172 |  |  |  |