| 1 |  |  | import pprint | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import warnings | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from collections import Counter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from .regularization.regularizers_factory import REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | class TopicModel(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |     An instance of this class encapsulates the behaviour of a topic_model. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |     - LIMITATION: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |         Does not support dynamic changing of the number of topics between different training cycles. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     def __init__(self, label, artm_model, evaluators): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |         Creates a topic model object. The model is ready to add regularizers to it.\n | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |         :param str label: a unique identifier for the model; model name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |         :param artm.ARTM artm_model: a reference to an artm model object | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |         :param dict evaluators: mapping from evaluator-definitions strings to patm.evaluation.base_evaluator.ArtmEvaluator objects | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |          ; i.e. {'perplexity': obj0, 'sparsity-phi-\@dc': obj1, 'sparsity-phi-\@ic': obj2, 'topic-kernel-0.6': obj3, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |          'topic-kernel-0.8': obj4, 'top-tokens-10': obj5} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |         self.label = label | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |         self.artm_model = artm_model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |         self._definition2evaluator = evaluators | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |         self.definition2evaluator_name = {k: v.name for k, v in evaluators.items()} # {'sparsity-phi': 'spd_p', 'top-tokens-10': 'tt10'} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |         self._evaluator_name2definition = {v.name: k for k, v in evaluators.items()} # {'sp_p': 'sparsity-phi', 'tt10': 'top-tokens-10'} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |         self._reg_longtype2name = {}  # ie {'smooth-theta': 'smb_t', 'sparse-phi': 'spd_p'} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |         self._reg_name2wrapper = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |     def add_regularizer_wrapper(self, reg_wrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |         self._reg_name2wrapper[reg_wrapper.name] = reg_wrapper | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |         self._reg_longtype2name[reg_wrapper.long_type] = reg_wrapper.name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |         self.artm_model.regularizers.add(reg_wrapper.artm_regularizer) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |     def add_regularizer_wrappers(self, reg_wrappers): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |         for _ in reg_wrappers: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |             self.add_regularizer_wrapper(_) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |     def pformat_regularizers(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         def _filter(reg_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |             d = self._reg_name2wrapper[reg_name].static_parameters | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |             if 'topic_names' in d: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |                 del d['topic_names'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |             return d | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |         return pprint.pformat({ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |             reg_long_type: dict(_filter(reg_name), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |                                 **{k:v for k,v in {'target topics': (lambda x: 'all' if len(x) == 0 else '[{}]'. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |                                                                                format(', '.join(x)))(self.get_reg_obj(reg_name).topic_names), | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |                                                    'mods': getattr(self.get_reg_obj(reg_name), 'class_ids', None)}.items()} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                                 ) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |             for reg_long_type, reg_name in self._reg_longtype2name.items() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |         }) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     def pformat_modalities(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |         return pprint.pformat(self.modalities_dictionary) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |     def initialize_regularizers(self, collection_passes, document_passes): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |         Call this method before starting of training to build some regularizers settings from run-time parameters.\n | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         :param int collection_passes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         :param int document_passes: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         self._trajs = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         for reg_name, wrapper in self._reg_name2wrapper.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |             self._trajs[reg_name] = wrapper.get_tau_trajectory(collection_passes) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |             wrapper.set_alpha_iters_trajectory(document_passes) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |     def get_reg_wrapper(self, reg_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         return self._reg_name2wrapper.get(reg_name, None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |     def regularizer_names(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |         return sorted(_ for _ in self.artm_model.regularizers.data) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |     def regularizer_wrappers(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |         return list(map(lambda x: self._reg_name2wrapper[x], self.regularizer_names)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |     def regularizer_types(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         return map(lambda x: self._reg_name2wrapper[x].type, self.regularizer_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |     def regularizer_unique_types(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         return map(lambda x: self._reg_name2wrapper[x].long_type, self.regularizer_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |     def long_types_n_types(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         return map(lambda x: (self._reg_name2wrapper[x].long_type, self._reg_name2wrapper[x].type), self.regularizer_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |     def evaluator_names(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |         return sorted(_ for _ in self.artm_model.scores.data) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |     def evaluator_definitions(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         return [self._evaluator_name2definition[eval_name] for eval_name in self.evaluator_names] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |     def evaluators(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         return [self._definition2evaluator[ev_def] for ev_def in self.evaluator_definitions] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |     def tau_trajectories(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |         return filter(lambda x: x[1] is not None, self._trajs.items()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |     def nb_topics(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |         return self.artm_model.num_topics | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |     def topic_names(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |         return self.artm_model.topic_names | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |     def domain_topics(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |         """Returns the mostly agreed list of topic names found in all evaluators""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  |         c = Counter() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |         for definition, evaluator in self._definition2evaluator.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |             tn = getattr(evaluator, 'topic_names', None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |             if tn: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |             # if hasattr(evaluator.artm_score, 'topic_names'): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |                 # print definition, evaluator.artm_score.topic_names | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |                 # tn = evaluator.artm_score.topic_names | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |                 # if tn: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |                 c['+'.join(tn)] += 1 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |         if len(c) > 2: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |             warnings.warn("There exist {} different subsets of all the topic names targeted by evaluators".format(len(c))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |         # print c.most_common() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |         try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |             return c.most_common(1)[0][0].split('+') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |         except IndexError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |             raise IndexError("Failed to compute domain topic names. Please enable at least one score with topics=domain_names argument. Scores: {}". | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |                              format(['{}: {}'.format(x.name, x.settings) for x in self.evaluators])) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |     def background_topics(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |         return [topic_name for topic_name in self.artm_model.topic_names if topic_name not in self.domain_topics] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |     def background_tokens(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |         background_tokens_eval_name = '' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |         for eval_def, eval_name in self.definition2evaluator_name.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  |             if eval_def.startswith('background-tokens-ratio-'): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  |                 background_tokens_eval_name = eval_name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |         if background_tokens_eval_name: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |             res = self.artm_model.score_tracker[background_tokens_eval_name].tokens | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |             return list(res[-1]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |     def modalities_dictionary(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |         return self.artm_model.class_ids | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |     def document_passes(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  |         return self.artm_model.num_document_passes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |     def get_reg_obj(self, reg_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |         return self.artm_model.regularizers[reg_name] | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 166 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 167 |  |  |     def get_reg_name(self, reg_type): | 
            
                                                                        
                            
            
                                    
            
            
                | 168 |  |  |         return self._reg_longtype2name[reg_type] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |         # try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |         #     return self._reg_longtype2name[reg_type] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |         # except KeyError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |         #     print '{} is not found in {}'.format(reg_type, self._reg_longtype2name) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  |         #     import sys | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |         #     sys.exit(1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |     def get_evaluator(self, eval_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 177 |  |  |         return self._definition2evaluator[self._evaluator_name2definition[eval_name]] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 178 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 179 |  |  |     # def get_evaluator_by_def(self, definition): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 180 |  |  |     #     return self._definition2evaluator[definition] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 181 |  |  |     # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 182 |  |  |     # def get_scorer_by_name(self, eval_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 183 |  |  |     #     return self.artm_model.scores[eval_name] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 184 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 185 |  |  |     # def get_targeted_topics_per_evaluator(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 186 |  |  |     #     tps = [] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 187 |  |  |     #     for evaluator in self.artm_model.scores.data: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 188 |  |  |     #         if hasattr(self.artm_model.scores[evaluator], 'topic_names'): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |     #             tps.append((evaluator, self.artm_model.scores[evaluator].topic_names)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 190 |  |  |     #     return tps | 
            
                                                                                                            
                            
            
                                    
            
            
                | 191 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  |     # def get_formated_topic_names_per_evaluator(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  |     #     return 'MODEL topic names:\n{}'.format('\n'.join(map(lambda x: ' {}: [{}]'.format(x[0], ', '.join(x[1])), self.get_targeted_topics_per_evaluator()))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |     @document_passes.setter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |     def document_passes(self, iterations): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |         self.artm_model.num_document_passes = iterations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |     def set_parameter(self, reg_name, reg_param, value): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |         if reg_name in self.artm_model.regularizers.data: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |             if hasattr(self.artm_model.regularizers[reg_name], reg_param): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |                 try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |                     # self.artm_model.regularizers[reg_name].__setattr__(reg_param, parameter_name2encoder[reg_param](value)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |                     self.artm_model.regularizers[reg_name].__setattr__(reg_param, value) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |                 except (AttributeError, TypeError) as e: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |                     print(e) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |                 raise ParameterNameNotFoundException("Regularizer '{}' with name '{}' has no attribute (parameter) '{}'".format(type(self.artm_model.regularizers[reg_name]).__name__, reg_name, reg_param)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |             raise RegularizerNameNotFoundException("Did not find a regularizer in the artm_model with name '{}'".format(reg_name)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |     def set_parameters(self, reg_name2param_settings): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |         for reg_name, settings in reg_name2param_settings.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |             for param, value in settings.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |                 self.set_parameter(reg_name, param, value) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |     def get_regs_param_dict(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  |         Returns a mapping between the model's regularizers unique string definition (ie shown in train.cfg: eg 'smooth-phi', 'sparse-theta', | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  |         'label-regulariation-phi-cls', 'decorrelate-phi-dom-def') and their corresponding parameters that can be dynamically changed during training (eg 'tau', 'gamma').\n | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |         See patm.modeling.regularization.regularizers.REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH\n | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |         :return: the regularizer type (str) to parameters (list of strings) mapping (str => list) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |         :rtype: dict | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |         d = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  |         for unique_type, reg_type in self.long_types_n_types: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  |             d[unique_type] = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  |             cur_reg_obj = self.artm_model.regularizers[self._reg_longtype2name[unique_type]] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |             for attribute_name in REGULARIZER_TYPE_2_DYNAMIC_PARAMETERS_HASH[reg_type]:  # ie for _ in ('tau', 'gamma') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  |                 d[unique_type][attribute_name] = getattr(cur_reg_obj, attribute_name) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  |         return d | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |     # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  |     # def _get_header(self, max_lens, topic_names): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  |     #     assert len(max_lens) == len(topic_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  |     #     return ' - '.join(map(lambda x: '{}{}'.format(x[1], ' ' * (max_lens[x[0]] - len(name))), (j, name in enumerate(topic_names)))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  |     # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  |     # def _get_rows(self, topic_name2tokens): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  |     #     max_token_lens = [max(map(lambda x: len(x), topic_name2tokens[name])) for name in self.artm_model.topic_names] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  |     #     b = '' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 242 |  |  |     #     for i in range(len(topic_name2tokens.values()[0])): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 243 |  |  |     #         b += ' | '.join('{} {}'.format(topic_name2tokens[name][i], (max_token_lens[j] - len(topic_name2tokens[name][i])) * ' ') for j, name in enumerate(self.artm_model.topic_names)) + '\n' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 244 |  |  |     #     return b, max_token_lens | 
            
                                                                                                            
                            
            
                                    
            
            
                | 245 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 246 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 247 |  |  | class TrainSpecs(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 248 |  |  |     def __init__(self, collection_passes, reg_names, tau_trajectories): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 249 |  |  |         self._col_iter = collection_passes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 250 |  |  |         assert len(reg_names) == len(tau_trajectories) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 251 |  |  |         self._reg_name2tau_trajectory = dict(zip(reg_names, tau_trajectories)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 252 |  |  |         # print "TRAIN SPECS", self._col_iter, map(lambda x: len(x), self._reg_name2tau_trajectory.values()) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 253 |  |  |         assert all(map(lambda x: len(x) == self._col_iter, self._reg_name2tau_trajectory.values())) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 254 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 255 |  |  |     def tau_trajectory(self, reg_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 256 |  |  |         return self._reg_name2tau_trajectory.get(reg_name, None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 257 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 258 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 259 |  |  |     def tau_trajectory_list(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 260 |  |  |         """Returns the list of (reg_name, tau trajectory) pairs, sorted alphabetically by regularizer name""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 261 |  |  |         return sorted(self._reg_name2tau_trajectory.items(), key=lambda x: x[0]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 262 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 263 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 264 |  |  |     def collection_passes(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 265 |  |  |         return self._col_iter | 
            
                                                                                                            
                            
            
                                    
            
            
                | 266 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 267 |  |  |     def to_taus_slice(self, iter_count): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 268 |  |  |         return {reg_name: {'tau': trajectory[iter_count]} for reg_name, trajectory in self._reg_name2tau_trajectory.items()} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 269 |  |  |         # return dict(zip(self._reg_name2tau_trajectory.keys(), map(lambda x: x[iter_count], self._reg_name2tau_trajectory.values()))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 270 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 271 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 272 |  |  | class RegularizerNameNotFoundException(Exception): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 273 |  |  |     def __init__(self, msg): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 274 |  |  |         super(RegularizerNameNotFoundException, self).__init__(msg) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 275 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 276 |  |  | class ParameterNameNotFoundException(Exception): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 277 |  |  |     def __init__(self, msg): | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 278 |  |  |         super(ParameterNameNotFoundException, self).__init__(msg) | 
            
                                                        
            
                                    
            
            
                | 279 |  |  |  |