| 1 |  |  | import os | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | import re | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  | import warnings | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | from glob import glob | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | from functools import reduce | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | import attr | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | from topic_modeling_toolkit.results import ExperimentalResults | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  | from .fitness import FitnessFunction | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                # Kernel sub-entities; _get_kernel_sub_hash builds one 'kernel-<entity>' column definition per entry.
                KERNEL_SUB_ENTITIES = ('coherence', 'contrast', 'purity', 'size')

                MAX_DECIMALS = 2  # this constant should agree with the patm.modeling.experiment.Experiment.MAX_DECIMALS
            
                                                                                                            
                            
            
                                    
            
            
                | 17 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _get_kernel_sub_hash(entity):
                    """Build the column definition for one kernel sub-entity ('kernel-<entity>').

                    The extractors read ``results.tracked.kernel<suffix>.average.<entity>`` and return
                    its ``last`` (scalar) or ``all`` (list) value, or None when that tracked kernel
                    attribute does not exist.

                    :param str entity: one of KERNEL_SUB_ENTITIES
                    :return: a one-key dict mapping 'kernel-<entity>' to its definition with the keys
                        'scalar-extractor', 'list-extractor', 'column-title', 'to-string' and 'definitions'
                    :rtype: dict
                    :raises ValueError: if ``entity`` is not one of KERNEL_SUB_ENTITIES
                    """
                    # Explicit check rather than `assert`, which is removed when Python runs with -O.
                    if entity not in KERNEL_SUB_ENTITIES:
                        raise ValueError("Unknown kernel sub-entity '{}'; expected one of {}".format(entity, KERNEL_SUB_ENTITIES))

                    def _average_of(results, threshold, quantity):
                        # Attribute name is 'kernel' + threshold without its first two characters
                        # (presumably '0.60' -> 'kernel60'; confirm the threshold string format).
                        attribute_name = 'kernel' + threshold[2:]
                        if not hasattr(results.tracked, attribute_name):
                            return None
                        average = getattr(results.tracked, attribute_name).average
                        return getattr(getattr(average, entity), quantity)

                    return {'kernel-' + entity: {
                        'scalar-extractor': lambda x, y: _average_of(x, y, 'last'),
                        'list-extractor': lambda x, y: _average_of(x, y, 'all'),
                        # Title uses the 1st and 3rd letters of the entity, e.g. 'coherence' -> 'kch.60'.
                        'column-title': lambda x: 'k' + entity[:3:2] + '.' + str(x)[2:],
                        'to-string': '{:.4f}',
                        'definitions': lambda x: ['kernel-{}-{}'.format(entity, y) for y in x.tracked.kernel_thresholds],
                    }}
            
                                                                                                            
                            
            
                                    
            
            
                | 25 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _tracked_metric(results, attribute_name, *path):
                    """Return ``results.tracked.<attribute_name>`` followed through ``path`` attributes.

                    Returns None when ``results.tracked`` has no ``attribute_name`` attribute. Only that
                    first attribute is optional; a missing attribute along ``path`` raises AttributeError.
                    """
                    tracked = results.tracked
                    if not hasattr(tracked, attribute_name):
                        return None
                    value = getattr(tracked, attribute_name)
                    for step in path:
                        value = getattr(value, step)
                    return value


                # Column definitions keyed by column name. Each definition may contain:
                #   'scalar-extractor' / 'list-extractor': callables returning the `last` / `all` recorded
                #       value(s); dynamic columns take a second argument (threshold, cardinality or modality)
                #   'column-title': callable producing the short column header
                #   'to-string': format string applied to the value
                #   'definitions': callable listing the concrete column names available for a results object
                COLUMNS_HASH = {
                    'nb-topics': {'scalar-extractor': lambda x: x.scalars.nb_topics,
                                  'column-title': lambda: 'tpcs'},
                    'collection-passes': {'scalar-extractor': lambda x: x.scalars.dataset_iterations,
                                          'column-title': lambda: 'col-i'},
                    'document-passes': {'scalar-extractor': lambda x: x.scalars.document_passes,
                                        'column-title': lambda: 'doc-i'},
                    'total-phi-updates': {'scalar-extractor': lambda x: x.scalars.dataset_iterations * x.scalars.document_passes,
                                          'column-title': lambda: 'phi-u'},
                    'perplexity': {'scalar-extractor': lambda x: x.tracked.perplexity.last,
                                   'list-extractor': lambda x: x.tracked.perplexity.all,
                                   'column-title': lambda: 'prpl',
                                   'to-string': '{:.1f}'},
                    'top-tokens-coherence': {'scalar-extractor': lambda x, y: _tracked_metric(x, 'top' + str(y), 'average_coherence', 'last'),
                                             'list-extractor': lambda x, y: _tracked_metric(x, 'top' + str(y), 'average_coherence', 'all'),
                                             'column-title': lambda x: 'top' + str(x) + 'ch',
                                             'to-string': '{:.4f}',
                                             'definitions': lambda x: ['top-tokens-coherence-' + str(y) for y in x.tracked.top_tokens_cardinalities]},
                    'sparsity-phi': {'scalar-extractor': lambda x, y: _tracked_metric(x, 'sparsity_phi_' + y, 'last'),
                                     'list-extractor': lambda x, y: _tracked_metric(x, 'sparsity_phi_' + y, 'all'),
                                     'column-title': lambda y: 'spp@' + y,
                                     'to-string': '{:.2f}',
                                     'definitions': lambda x: ['sparsity-phi-{}'.format(y) for y in x.tracked.modalities_initials]},
                    'sparsity-theta': {'scalar-extractor': lambda x: x.tracked.sparsity_theta.last,
                                       'list-extractor': lambda x: x.tracked.sparsity_theta.all,
                                       'column-title': lambda: 'spt',
                                       'to-string': '{:.2f}'},
                    # The threshold loses its first two characters in the attribute name and title
                    # (presumably '0.3' -> 'background_tokens_ratio_3'; confirm threshold format).
                    'background-tokens-ratio': {'scalar-extractor': lambda x, y: _tracked_metric(x, 'background_tokens_ratio_' + str(y)[2:], 'last'),
                                                'list-extractor': lambda x, y: _tracked_metric(x, 'background_tokens_ratio_' + str(y)[2:], 'all'),
                                                'column-title': lambda x: 'btr.' + str(x)[2:],
                                                'to-string': '{:.2f}',
                                                'definitions': lambda x: ['background-tokens-ratio-{}'.format(y[:4]) for y in x.tracked.background_tokens_thresholds]},
                    'regularizers': {'scalar-extractor': lambda x: '[{}]'.format(', '.join(map(regularizers_format, x.regularizers))),
                                     'column-title': lambda: 'regs'},
                    # Explicit kernel-size entry: sizes are formatted with one decimal.
                    'kernel-size': {'scalar-extractor': lambda x, y: _tracked_metric(x, 'kernel' + y[2:], 'average', 'size', 'last'),
                                    'list-extractor': lambda x, y: _tracked_metric(x, 'kernel' + y[2:], 'average', 'size', 'all'),
                                    'column-title': lambda x: 'ksz.' + str(x)[2:],
                                    'to-string': '{:.1f}',
                                    'definitions': lambda x: ['kernel-size-{}'.format(y) for y in x.tracked.kernel_thresholds]},
                }
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def regularizers_format(reg_def_string):
                    """Format one regularizer definition string for the 'regularizers' column.

                    Currently the identity: the definition string is returned unchanged. (A disabled
                    variant that abbreviated the part before the '|' separator was dead code and has
                    been removed.)

                    :param str reg_def_string: a regularizer definition string
                    :return: ``reg_def_string``, unchanged
                    :rtype: str
                    """
                    return reg_def_string
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                # Add the generated 'kernel-<entity>' definitions to COLUMNS_HASH. Later pairs override
                # earlier ones in dict(), so entries already in COLUMNS_HASH (e.g. the explicit
                # 'kernel-size') take precedence over the generated ones. Generated keys come first,
                # in reverse KERNEL_SUB_ENTITIES order.
                COLUMNS_HASH = dict(
                    [pair for entity in reversed(KERNEL_SUB_ENTITIES) for pair in _get_kernel_sub_hash(entity).items()]
                    + list(COLUMNS_HASH.items())
                )
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  | class ResultsHandler(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |     _list_selector_hash = {str: lambda x: x[0] if x[1] == 'all' else None, | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 |  |  |                            range: lambda x: [x[0][_] for _ in x[1]], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 |  |  |                            int: lambda x: x[0][:x[1]], | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 |  |  |                            list: lambda x: [x[0][_] for _ in x[1]]} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 |  |  |     _QUANTITY_2_EXTRACTOR_KEY = {'last': 'scalar', 'all': 'list'} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 |  |  |     DYNAMIC_COLUMNS = ['kernel-size', 'kernel-coherence', 'kernel-contrast', 'kernel-purity', 'top-tokens-coherence', 'sparsity-phi', 'background-tokens-ratio'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |     DEFAULT_COLUMNS = ['nb-topics', 'collection-passes', 'document-passes', 'total-phi-updates', 'perplexity'] +\ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 |  |  |                       DYNAMIC_COLUMNS[:-1] + ['sparsity-theta'] + [DYNAMIC_COLUMNS[-1]] + ['regularizers'] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |     def __init__(self, collections_root_path, results_dir_name='results'): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |         self._collections_root = collections_root_path | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |         self._results_dir_name = results_dir_name | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |         self._results_hash = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         self._fitness_function_hash = {} | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         self._list_selector = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def get_experimental_results(self, collection_name, sort='', selection='all'):
                    """
                    Call this method to get a list of experimental result objects from topic models trained on the given collection.

                    :param str collection_name: name of the collection whose results directory is scanned for '*.json' files
                    :param str sort: metric to sort the results by; if falsy (the default '' or None) the experimental results
                        are ordered alphabetically on their json path
                    :param str or range or int or list selection: whether to select a subset of the experimental results from the given collection

                        - if selection == 'all', returns every experimental results object "extracted" from the jsons
                        - if type(selection) == range, returns a "slice" of the experimental results based on the range
                        - if type(selection) == int, returns the first n experimental results
                        - if type(selection) == list of ints, then it represents specific indices to sample the list of experimental results from
                        - if type(selection) == list of str, then it holds model labels: the results whose json file name (without
                          the '.json' extension) equals one of the labels are returned, ordered by json path ('sort' is ignored)
                    :return: the ExperimentalResults objects
                    :rtype: list
                    :raises ValueError: when selecting by model labels and the number of matched results differs from the number of labels
                    """
                    results_dir = os.path.join(self._collections_root, collection_name, self._results_dir_name)
                    result_paths = glob('{}/*.json'.format(results_dir))
                    if type(selection) == list and all(type(x) == str for x in selection):  # if input list contains model labels
                        # Match file stems exactly rather than building a regex from the labels: labels may contain regex
                        # metacharacters (e.g. '.'), and glob returns OS-native separators that a '/'-anchored pattern misses.
                        labels = set(selection)
                        matching_paths = [p for p in result_paths if os.path.splitext(os.path.basename(p))[0] in labels]
                        e = self._get_experimental_results(matching_paths)
                        if len(e) != len(selection):
                            raise ValueError("len1 = {}, len2 = {}\nseq1 = {}\nseq2 = {}".format(len(e), len(selection), [_.scalars.model_label for _ in e], selection))
                        return e
                    # stored on the instance; the concrete selector is chosen by the type of 'selection'
                    self._list_selector = lambda y: ResultsHandler._list_selector_hash[type(selection)]([y, selection])
                    r = self._get_experimental_results(result_paths, metric_sorter=self._get_metric(sort))

                    assert len(result_paths) == len(r)
                    return self._list_selector(r)
            
                                                                                                            
                            
            
                                    
            
            
                | 120 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _get_experimental_results(self, results_paths, metric_sorter=None):
                    """
                    Load (through the per-path cache) the experimental results stored at the given json paths.

                    :param list results_paths: paths of the json files to load
                    :param metric_sorter: optional callable applied to the whole list of loaded results (e.g. the object
                        returned by _get_metric); its return value is returned as-is. When None, results are returned in
                        alphabetical order of their paths.
                    :return: the loaded (and possibly sorted) results
                    """
                    if metric_sorter is not None:
                        loaded = [self._process_result_path(path) for path in results_paths]
                        return metric_sorter(loaded)
                    ordered_paths = sorted(results_paths)
                    return [self._process_result_path(path) for path in ordered_paths]
            
                                                                                                            
                            
            
                                    
            
            
                | 135 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _process_result_path(self, result_path):
                    """
                    Return the ExperimentalResults for a json path, creating it from the file only on first request.

                    :param str result_path: path to a results json file
                    :return: the cached ExperimentalResults object for that path
                    """
                    cache = self._results_hash
                    if result_path in cache:
                        return cache[result_path]
                    loaded = ExperimentalResults.create_from_json_file(result_path)
                    cache[result_path] = loaded
                    return loaded
            
                                                                                                            
                            
            
                                    
            
            
                | 140 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                def _get_metric(self, metric):
                    """
                    Build a sorter that ranks experimental results by the last recorded value of a metric.

                    :param str metric: column definition of the metric; a falsy value means "no metric sorting"
                    :return: a MetricSorter wrapping a scoring callable, or None when no metric was requested
                    """
                    if metric:
                        if metric not in self._fitness_function_hash:
                            # memoize one single-metric fitness function per metric name
                            self._fitness_function_hash[metric] = FitnessFunction.single_metric(metric)

                        def score(exp_results):
                            # the fitness function is looked up at call time, from the instance's hash
                            fitness = self._fitness_function_hash[metric]
                            return fitness.compute([ResultsHandler.extract(exp_results, metric, 'last')])

                        return MetricSorter(metric, score)
                    return None
            
                                                                                                            
                            
            
                                    
            
            
                | 147 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def get_titles(column_definitions):
                    """
                    Map each column definition to its column title.

                    :param column_definitions: iterable of column definitions
                    :return: the titles, in the same order as the definitions
                    :rtype: list
                    """
                    titles = []
                    for definition in column_definitions:
                        titles.append(ResultsHandler.get_abbreviation(definition))
                    return titles
            
                                                                                                            
                                                                
            
                                    
            
            
                | 151 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                @staticmethod
                def get_abbreviation(definition):
                    """
                    Build the column title for a definition with the COLUMNS_HASH 'column-title' callable of its token key.

                    :param str definition: '-'-separated column definition; its parameters are passed positionally to the callable
                    :return: whatever the 'column-title' callable returns
                    """
                    tokens, parameters = ResultsHandler._parse_column_definition(definition)
                    title_builder = COLUMNS_HASH['-'.join(tokens)]['column-title']
                    return title_builder(*parameters)
            
                                                                                                            
                            
            
                                    
            
            
                | 156 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def stringnify(column, value):
                    """
                    Format a value with the 'to-string' template of its column ('{}' when the column defines none).

                    :param str column: key or definition; example values: 'perplexity', 'kernel-coherence', 'kernel-coherence-0.80'
                    :param value: the value to format
                    :return: the formatted value
                    :rtype: str
                    :raises KeyError: if neither 'column' nor its token-only key (see _get_hash_key) is in COLUMNS_HASH
                    """
                    if column in COLUMNS_HASH:
                        entry = COLUMNS_HASH[column]
                    else:
                        # Fall back to the key made of the definition's tokens only. This is computed lazily: an eagerly
                        # evaluated fallback raises KeyError whenever the token-only key is unregistered, even if 'column'
                        # itself is a valid key.
                        entry = COLUMNS_HASH[ResultsHandler._get_hash_key(column)]
                    return entry.get('to-string', '{}').format(value)
            
                                                                                                            
                            
            
                                    
            
            
                | 166 |  |  |          | 
            
                                                                                                            
                            
            
                                    
            
            
                | 167 |  |  |     # @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 168 |  |  |     # def get_tau_trajectory(exp_results, matrix_name): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 169 |  |  |     #     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 170 |  |  |     # | 
            
                                                                                                            
                            
            
                                    
            
            
                | 171 |  |  |     #     :param results.experimental_results.ExperimentalResults exp_results: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 172 |  |  |     #     :param matrix_name: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 173 |  |  |     #     :return: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 174 |  |  |     #     """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 175 |  |  |     #     return getattr(exp_results.tracked.tau_trajectories, matrix_name).all | 
            
                                                                                                            
                            
            
                                    
            
            
                | 176 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def extract(exp_results, column_definition, quantity):
                    """
                    Call this method to query the given experimental results object about a specific metric. Supports requesting all
                    values tracked along the training process.

                    The column's token key selects the COLUMNS_HASH entry, the quantity selects its '<key>-extractor' callable, and
                    that callable is invoked with the results object followed by the definition's parameters.

                    :param results.experimental_results.ExperimentalResults exp_results:
                    :param str column_definition: '-'-separated column definition, e.g. 'kernel-coherence-0.80'
                    :param str quantity: must be one of {'last', 'all'}
                    :return: whatever the selected extractor returns
                    """
                    tokens, parameters = ResultsHandler._parse_column_definition(column_definition)
                    column_entry = COLUMNS_HASH['-'.join(tokens)]
                    extractor = column_entry[ResultsHandler._QUANTITY_2_EXTRACTOR_KEY[quantity] + '-extractor']
                    return extractor(exp_results, *parameters)
            
                                                                                                            
                            
            
                                    
            
            
                | 189 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def get_all_columns(exp_results, requested_entities):
                    """
                    Expand the requested entities into a flat list of column definitions.

                    Entities listed in ResultsHandler.DYNAMIC_COLUMNS are expanded by calling their COLUMNS_HASH 'definitions'
                    callable on 'exp_results'; every other entity is kept as a single column.

                    :param exp_results: experimental results object used to expand dynamic columns
                    :param requested_entities: iterable of column keys/definitions
                    :return: a new flat list of column definitions; empty when nothing is requested
                    :rtype: list
                    """
                    # A plain extend loop replaces reduce(+): it works for an empty request (reduce raised TypeError),
                    # avoids quadratic list concatenation, and never returns the 'definitions' list object itself.
                    columns = []
                    for entity in requested_entities:
                        if entity in ResultsHandler.DYNAMIC_COLUMNS:
                            columns.extend(COLUMNS_HASH[entity]['definitions'](exp_results))
                        else:
                            columns.append(entity)
                    return columns
            
                                                                                                            
                            
            
                                    
            
            
                | 194 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 195 |  |  |     ###### UTILITY FUNCTIONS ###### | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def _parse_column_definition(definition):
                    """
                    Split a '-'-separated column definition into name tokens and parameters.

                    Elements for which ResultsHandler._is_token is truthy go to the tokens, all others to the parameters.
                    Empty elements are dropped from both, and relative order is preserved.

                    :param str definition: e.g. 'kernel-coherence-0.80'
                    :return: [tokens, parameters], two lists of str
                    :rtype: list
                    """
                    tokens = []
                    parameters = []
                    for element in definition.split('-'):
                        bucket = tokens if ResultsHandler._is_token(element) else parameters
                        if element:
                            bucket.append(element)
                    return [tokens, parameters]
            
                                                                                                            
                            
            
                                    
            
            
                | 199 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                @staticmethod
                def _get_hash_key(column_definition):
                    """
                    Rebuild the COLUMNS_HASH key of a definition by keeping only its token elements.

                    :param str column_definition: '-'-separated column definition
                    :return: the token elements re-joined with '-'
                    :rtype: str
                    """
                    elements = column_definition.split('-')
                    return '-'.join(filter(ResultsHandler._is_token, elements))
            
                                                                                                            
                            
            
                                    
            
            
                | 203 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 204 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 205 |  |  |     def _is_token(definition_element): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 206 |  |  |         try: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 207 |  |  |             _ = float(definition_element) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 208 |  |  |             return False | 
            
                                                                                                            
                            
            
                                    
            
            
                | 209 |  |  |         except ValueError: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 210 |  |  |             if definition_element[0] == '@' or len(definition_element) == 1: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 211 |  |  |                 return False | 
            
                                                                                                            
                            
            
                                    
            
            
                | 212 |  |  |             return True | 
            
                                                                                                            
                            
            
                                    
            
            
                | 213 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 214 |  |  |     @staticmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 215 |  |  |     def _label_selection(labels, experimental_results_list): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 216 |  |  |         """Returns the indices of the input labels based on the input experimental results list""" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 217 |  |  |         model_labels = [x.scalars.model_label for x in experimental_results_list] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 218 |  |  |         return [experimental_results_list.index(model_labels.index(l)) for l in labels if l in model_labels] | 
            
                                                                                                            
                            
            
                                    
            
            
                | 219 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 220 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 221 |  |  | ########################## | 
            
                                                                                                            
                            
            
                                    
            
            
                | 222 |  |  | def _build_sorter(instance, attribute, value): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 223 |  |  |     if not hasattr(value, '__call__'): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 224 |  |  |         raise TypeError("Second constructor argument should be a callable object, in case the first is 'alphabetical'") | 
            
                                                                                                            
                            
            
                                    
            
            
                | 225 |  |  |     # this is sorting from 'bigger' to 'smaller' (independently of how <,>, operators have been defined) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 226 |  |  |     instance.experimental_result_sorter = lambda exp_res_objs_list: sorted(exp_res_objs_list, key=value, reverse=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 227 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 228 |  |  | @attr.s(cmp=True, repr=True, str=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 229 |  |  | class MetricSorter(object): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 230 |  |  |     name = attr.ib(init=True, converter=str, cmp=True, repr=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 231 |  |  |     experimental_result_sorter = attr.ib(init=True, validator=_build_sorter, cmp=True, repr=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 232 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 233 |  |  |     @classmethod | 
            
                                                                                                            
                            
            
                                    
            
            
                | 234 |  |  |     def from_function(cls, function): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 235 |  |  |         return ResultsSorter(function.__name__, function) | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 236 |  |  |     def __call__(self, *args, **kwargs): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 237 |  |  |         return self.experimental_result_sorter(args[0]) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 238 |  |  | ############################ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 239 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 240 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 241 |  |  | if __name__ == '__main__': | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 242 |  |  |     ms = ModelSelector(collection_results_dir_path='/data/thesis/data/collections/dd/results') | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                        
            
                                    
            
            
                | 243 |  |  |  |