| 1 |  |  | import abc | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | from warnings import warn | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | import artm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | from .trajectory import TrajectoryBuilder | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | import logging | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | logger = logging.getLogger(__name__) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  | class ArtmRegularizerWrapper(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |     __metaclass__ = abc.ABCMeta | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 |  |  |     subclasses = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |     labeling_ordering = ('tau', 'alpha_iter') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 |  |  |     _traj_type2traj_def_creator = {'alpha_iter': lambda x: '0_' + x[1], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |                               'tau': lambda x: '{}_{}'.format(x[0], x[1])} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 |  |  |     def __init__(self, parameters_dict, verbose=False): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 |  |  |         self.trajectory_builder = TrajectoryBuilder() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 |  |  |         self._regularizer = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  |         self._alpha_iter_scalar = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 |  |  |         self._trajectory_lambdas = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 |  |  |         self._traj_def = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |         self._reg_constr_params = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |         self._params_for_labeling = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 27 |  |  |         self._start = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 |  |  |         self._name = parameters_dict.pop('name', 'no-name') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 |  |  |         self._long_type = parameters_dict.pop('long-type', 'type-not-found') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 |  |  |         if 'start' in parameters_dict: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |             self._start = parameters_dict['start'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 |  |  |         elif type(parameters_dict['tau']) == str: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 |  |  |             try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 |  |  |                 _ = float(parameters_dict['tau']) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 |  |  |             except ValueError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |                 self._start = parameters_dict['tau'].split('_')[0] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |                 parameters_dict['tau'] = '_'.join(parameters_dict['tau'].split('_')[1:]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 |  |  |         for k, v in parameters_dict.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 |  |  |             # in case it is none then parameter it is handled by the default behaviour of artm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 |  |  |             if v is None or type(v) == list or type(v) == artm.dictionary.Dictionary or k == 'class_ids':  # one of topic_names or class_ids or dictionary | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |                 self._reg_constr_params[k] = v | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |             else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |                 try:  # by this point v should be a string, if exception occurs is shoud be only if a trajectory is defined (eg 'linear_-2_-10') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |                     vf = float(v) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |                     if k == 'alpha_iter': | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |                         self._alpha_iter_scalar = vf  # case: alpha_iter = a constant scalar which will be used for each of the 'nb_doument_passes' iterations | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |                     else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |                         self._reg_constr_params[k] = vf # case: parameter_name == 'tau' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 51 |  |  |                     self._params_for_labeling[k] = vf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 52 |  |  |                 except ValueError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 53 |  |  |                     if self._start is None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 54 |  |  |                         print("INFO Defaulting to activating the regularizer from the 1st iteration") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 55 |  |  |                         self._start = 0 | 
            
                                                                                                            
                            
            
                                    
            
            
                | 56 |  |  |                     self._traj_def[k] = self._traj_type2traj_def_creator[k]([self._start, v])  # case: parameter_value is a trajectory definition without the 'start' setting (nb of initial iterations that regularizer stays inactive) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 57 |  |  |                     self._params_for_labeling[k] = self._traj_def[k] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 58 |  |  |         self._create_artm_regularizer(dict(self._reg_constr_params, **{'name': self._name})) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     def _create_artm_regularizer(self, parameters): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |         self._regularizer = self._artm_constructor(**parameters) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         logger.info("Built '{}'/'{}' reg, named '{}', with settings: {}".format(self.type, self._long_type, self._name, '{'+', '.join(map(lambda x: '{}={}'.format(x[0], x[1]), parameters.items()))+'}')) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |     def register_subclass(cls, regularizer_type): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |         def decorator(subclass): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |             cls.subclasses[regularizer_type] = subclass | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |             return subclass | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         return decorator | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |     def create(cls, regularizer_type, *args, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         if regularizer_type not in cls.subclasses: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 |  |  |             raise ValueError("Bad regularizer type '{}'".format(regularizer_type)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         return cls.subclasses[regularizer_type](*args, **kwargs) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |     def label(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |         return '{}|{}'.format(self._name, '|'.join(map(lambda x: '{}:{}'.format(x[0][0], x[1]), self._get_labeling_data()))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |     def _get_labeling_data(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |         return sorted(self._params_for_labeling.items(), key=lambda x: x[0]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |     def get_tau_trajectory(self, collection_passes): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |         if 'tau' in self._traj_def: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |             return self._create_trajectory('tau', collection_passes) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |         return None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |     def set_alpha_iters_trajectory(self, nb_document_passes): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         if 'alpha_iter' in self._traj_def: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |             self._regularizer.alpha_iter = list(self._create_trajectory('alpha_iter', nb_document_passes)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         elif self._alpha_iter_scalar: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |             self._regularizer.alpha_iter = [self._alpha_iter_scalar] * nb_document_passes | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |     def _create_trajectory(self, name, length): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         _ = self._traj_def[name].split('_') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         return self.trajectory_builder.begin_trajectory('tau')\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |             .deactivate(int(_[0]))\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |             .interpolate_to(length - int(_[0]), float(_[3]), interpolation=_[1], start=float(_[2]))\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 |  |  |             .create() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |     def static_parameters(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         return self._reg_constr_params | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |     def artm_regularizer(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 108 |  |  |         return self._regularizer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 109 |  |  |     def __str__(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 110 |  |  |         return self.name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 111 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 112 |  |  |     def name(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 113 |  |  |         return self._name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 114 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 115 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 116 |  |  |     def type(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 117 |  |  |         for k, v in ArtmRegularizerWrapper.subclasses.items(): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 118 |  |  |             if type(self) == v: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 119 |  |  |                 return k | 
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |     @property | 
            
                                                                                                            
                            
            
                                    
            
            
                | 121 |  |  |     def long_type(self): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 122 |  |  |         return self._long_type | 
            
                                                                                                            
                            
            
                                    
            
            
                | 123 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 124 |  |  | class SmoothSparseRegularizerWrapper(ArtmRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 125 |  |  |     __metaclass__ = abc.ABCMeta | 
            
                                                                                                            
                            
            
                                    
            
            
                | 126 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 127 |  |  |     def __init__(self, params_dict, targeted_topics): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 128 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 129 |  |  |         :param params_dict: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 130 |  |  |         :param targeted_topics: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 131 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 132 |  |  |         if len(targeted_topics) == 0: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 133 |  |  |             logger.warning("Did not specify topics to target with the '{}' {} regularizer. By default the Smooth regularizer will target all topics. " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 134 |  |  |                            "This is recommended only if all your regularizers target all topics (no notion of background-domain separation). " | 
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |                            "If you are modeling an LDA (plsa_formula + smoothing regularization over all topics), ignore this warning.".format(params_dict['name'], type(self).__name__)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 136 |  |  |             targeted_topics = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 137 |  |  |         super(SmoothSparseRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': targeted_topics})) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 138 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 139 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  | class SmoothSparsePhiRegularizerWrapper(SmoothSparseRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 141 |  |  |     _artm_constructor = artm.SmoothSparsePhiRegularizer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 142 |  |  |     def __init__(self, params_dict, topic_names, class_ids): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 143 |  |  |         super(SmoothSparsePhiRegularizerWrapper, self).__init__(dict(params_dict, **{'class_ids': class_ids}), topic_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 144 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 145 |  |  | @ArtmRegularizerWrapper.register_subclass('sparse-phi') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 146 |  |  | class SparsePhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |     def __init__(self, params_dict, topic_names, class_ids): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 148 |  |  |         super(SparsePhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 149 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 150 |  |  | @ArtmRegularizerWrapper.register_subclass('smooth-phi') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 151 |  |  | class SmoothPhiRegularizerWrapper(SmoothSparsePhiRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 152 |  |  |     def __init__(self, params_dict, topic_names, class_ids): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 153 |  |  |         super(SmoothPhiRegularizerWrapper, self).__init__(params_dict, topic_names, class_ids) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 154 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 155 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  | class SmoothSparseThetaRegularizerWrapper(SmoothSparseRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 157 |  |  |     _artm_constructor = artm.SmoothSparseThetaRegularizer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 158 |  |  |     def __init__(self, params_dict, topic_names): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 159 |  |  |         super(SmoothSparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 160 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 161 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 162 |  |  | @ArtmRegularizerWrapper.register_subclass('sparse-theta') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 163 |  |  | class SparseThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 164 |  |  |     def __init__(self, params_dict, topic_names): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 165 |  |  |         super(SparseThetaRegularizerWrapper, self).__init__(params_dict, topic_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  | @ArtmRegularizerWrapper.register_subclass('smooth-theta') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  | class SmoothThetaRegularizerWrapper(SmoothSparseThetaRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |     def __init__(self, params_dict, topic_names): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |         super(SmoothThetaRegularizerWrapper, self).__init__(params_dict, topic_names) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  | @ArtmRegularizerWrapper.register_subclass('label-regularization-phi')  # can be used to expand the probability space to DxWxTxC eg author-topic model | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  | class DocumentClassificationRegularizerWrapper(ArtmRegularizerWrapper): | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 174 |  |  |     _artm_constructor = artm.LabelRegularizationPhiRegularizer | 
            
                                                                        
                            
            
                                    
            
            
                | 175 |  |  |     def __init__(self, params_dict, topic_names, dictionary=None, class_ids=None): | 
            
                                                                        
                            
            
                                    
            
            
                | 176 |  |  |         """ | 
            
                                                                        
                            
            
                                    
            
            
                | 177 |  |  |         :param str name: | 
            
                                                                        
                            
            
                                    
            
            
                | 178 |  |  |         :param dict params_dict: Can contain keys: 'tau', 'gamma', 'dictionary' | 
            
                                                                        
                            
            
                                    
            
            
                | 179 |  |  |         :param list of str topic_names: list of names of topics to regularize, will regularize all topics if not specified. | 
            
                                                                        
                            
            
                                    
            
            
                | 180 |  |  |             Should correspond to the domain topics | 
            
                                                                        
                            
            
                                    
            
            
                | 181 |  |  |         :param list of str class_ids: class_ids to regularize, will regularize all classes if not specified | 
            
                                                                        
                            
            
                                    
            
            
                | 182 |  |  |         :param dictionary: | 
            
                                                                        
                            
            
                                    
            
            
                | 183 |  |  |         :param class_ids: | 
            
                                                                        
                            
            
                                    
            
            
                | 184 |  |  |         """ | 
            
                                                                        
                            
            
                                    
            
            
                | 185 |  |  |         if len(topic_names) == 0: # T.O.D.O below: the warning should fire if smooth is active because then there must be defined | 
            
                                                                        
                            
            
                                    
            
            
                | 186 |  |  |             # non overlapping sets of 'domain' and 'background' topics | 
            
                                                                        
                            
            
                                    
            
            
                | 187 |  |  |             warn("Set DocumentClassificationRegularizer to target all topics. This is valid only if you do use 'background topics'.") | 
            
                                                                        
                            
            
                                    
            
            
                | 188 |  |  |             topic_names = None | 
            
                                                                        
                            
            
                                    
            
            
                | 189 |  |  |         super(DocumentClassificationRegularizerWrapper, self).__init__(dict(params_dict, **{'topic_names': topic_names, | 
            
                                                                        
                            
            
                                    
            
            
                | 190 |  |  |                                                                                             'dictionary': dictionary, | 
            
                                                                        
                            
            
                                    
            
            
                | 191 |  |  |                                                                                             'class_ids': class_ids})) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 192 |  |  | @ArtmRegularizerWrapper.register_subclass('decorrelate-phi') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 193 |  |  | class PhiDecorrelator(ArtmRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |     _artm_constructor = artm.DecorrelatorPhiRegularizer | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |     def __init__(self, params_dict, topic_names, class_ids=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 196 |  |  |         super(PhiDecorrelator, self).__init__(dict(params_dict, **{'topic_names': topic_names, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 197 |  |  |                                                                    'class_ids': class_ids})) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 198 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  | @ArtmRegularizerWrapper.register_subclass('improve-coherence') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 200 |  |  | class ImproveCoherence(ArtmRegularizerWrapper): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 201 |  |  |     _artm_constructor = artm.ImproveCoherencePhiRegularizer # name=None, tau=1.0, class_ids=None, topic_names=None, dictionary=None, config=None) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 202 |  |  |     def __init__(self, params_dict, topic_names, dictionary, class_ids=None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |         super(ImproveCoherence, self).__init__(dict(params_dict, **{'dictionary': dictionary, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |                                                                     'topic_names': topic_names, | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 205 |  |  |                                                                     'class_ids': class_ids})) | 
            
                                                        
            
                                    
            
            
                | 206 |  |  |  |