reporting.reporter   C
last analyzed

Complexity

Total Complexity 57

Size/Duplication

Total Lines 283
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 199
dl 0
loc 283
rs 5.04
c 0
b 0
f 0
wmc 57

26 Methods

Rating   Name   Duplication   Size   Complexity  
A ModelReporter._build_renderable() 0 15 4
A ModelReporter._length() 0 5 2
A ModelReporter.get_formatted_string() 0 16 1
A ModelReporter._get_label() 0 7 2
A ModelReporter._get_labels_n_values() 0 12 5
A ModelReporter._extract_all() 0 6 2
A ModelReporter.columns_to_render() 0 3 4
A ModelReporter.determine_maximal_set_of_renderable_columns() 0 4 2
A ModelReporter._get_invalid_column_definitions() 0 3 1
A ModelReporter._parse_column_definition() 0 3 2
A ModelReporter._get_renderable() 0 13 2
A ModelReporter._extract() 0 3 1
A ModelReporter._get_maximal_renderable_columns() 0 12 2
A ModelReporter._to_string() 0 8 4
A ModelReporter._containing_column() 0 2 1
A ModelReporter._is_token() 0 9 4
A ModelReporter._get_column_definitions() 0 5 2
B ModelReporter._initialize() 0 34 6
A ModelReporter.determine_maximal_set_of_renderable_columns_debug() 0 9 2
A ModelReporter._to_row() 0 5 1
A ModelReporter._get_hash_key() 0 3 1
A ModelReporter._results_value_vectors() 0 3 1
A ModelReporter.__init__() 0 13 1
A ModelReporter._compute_rows() 0 3 1
A ModelReporter._to_list_of_strings() 0 2 1
A ModelReporter.exp_results() 0 6 1

How to fix   Complexity   

Complexity

Complex classes like ModelReporter (module reporting.reporter) often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import re
2
import os
3
import sys
4
from glob import glob
5
try:
    from collections.abc import Iterable  # Python 3.3+; the `collections` alias was removed in 3.10
except ImportError:  # very old interpreters only
    from collections import Iterable
6
7
from .fitness import FitnessCalculator
8
from .model_selection import ResultsHandler
9
from functools import reduce
10
11
12
import logging
13
# Module-level logger; its level is set explicitly to INFO instead of being inherited from ancestor loggers.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.INFO)
15
16
17
class ModelReporter:
    """Render a plain-text table comparing the trained models of a collection.

    One row is produced per results json file found under
    ``<collections_root_path>/<collection_name>/<results_dir_name>/`` and one column per
    column definition resolved through the ``ResultsHandler``. Rows are either kept in
    discovery order ('alphabetical' metric) or sorted by descending ``FitnessCalculator``
    score, and cells equal to the calculator's ``best`` value for their column are wrapped
    in ``highlight_pre_fix``/``highlight_post_fix`` (ANSI underline by default).
    """
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    ENDC = '\033[0m'
    # Matches ANSI SGR escape sequences (such as UNDERLINE/ENDC above); stripped when
    # measuring the visible width of a rendered cell.
    _ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*m')

    def __init__(self, collections_root_path, results_dir_name='results'):
        """
        :param str collections_root_path: directory holding one sub-directory per collection (dataset)
        :param str results_dir_name: name of the per-collection sub-directory holding the model results json files
        """
        self._collections_dir = collections_root_path
        self.results_handler = ResultsHandler(self._collections_dir, results_dir_name=results_dir_name)
        self._results_dir_name = results_dir_name
        self._label_separator = ':'
        self._columns_to_render = []
        self._column_keys_to_highlight = ['perplexity', 'kernel-coherence', 'kernel-purity', 'kernel-contrast', 'top-tokens-coherence',
                                          'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio']
        self.highlight_pre_fix = ModelReporter.UNDERLINE
        self.highlight_post_fix = ModelReporter.ENDC
        self.fitness_computer = FitnessCalculator()
        self._max_label_len = 0
        self._max_col_lens = []
        # State (re)computed by _initialize(). Pre-set here so that, e.g., assigning
        # columns_to_render before the first report raises InvalidColumnsException
        # instead of AttributeError.
        self._collection_name = None
        self._result_paths = []
        self._model_labels = []
        self._columns_failed = []
        self._columns_titles = []
        self._metric = ''
        self.maximal_requested_columns = set()
        self._maximal_renderable_columns = []

    def get_formatted_string(self, collection_name, columns=None, metric='', verbose=True):
        """Return the report table (a header line plus one line per model) of a collection.

        :param str collection_name: name of the collection sub-directory under the collections root
        :param list columns: column definitions to render; when empty/None every discovered column is rendered
        :param str metric: 'alphabetical' or one of the rendered column definitions to sort the models by.
            Note that the default '' is neither, so it raises InvalidMetricException.
        :param bool verbose: whether to print the used / omitted / failed column definitions
        :return: the formatted table
        :rtype: str
        :raises InvalidMetricException: if metric is not 'alphabetical' nor a rendered column
        :raises InvalidColumnsException: if none of the requested columns can be rendered
        """
        self._initialize(collection_name, columns=columns, metric=metric, verbose=verbose)
        # The body must be rendered before the header: rendering cells widens self._max_col_lens.
        body = '\n'.join(self._compute_rows())
        padded_titles = ' '.join(
            '{}{}'.format(title, ' ' * (self._max_col_lens[index] - len(title)))
            for index, title in enumerate(self._columns_titles[:-1]))
        head = '{}{} {} {}'.format(' ' * self._max_label_len,
                                   ' ' * len(self._label_separator),
                                   padded_titles,
                                   self._columns_titles[-1])
        return head + '\n' + body

    @property
    def exp_results(self):
        """Experimental results of every model of the current collection (fetched anew on each access).

        :rtype: list of results.experimental_results.ExperimentalResults
        """
        return self.results_handler.get_experimental_results(self._collection_name, selection='all')

    def _initialize(self, collection_name, columns=None, metric='', verbose=False):
        """Discover the models of a collection and resolve the labels, columns, titles and sort metric to report.

        :raises RuntimeError: if no results json file is found for the collection
        :raises InvalidMetricException: if metric is not 'alphabetical' nor a column to render
        :raises InvalidColumnsException: if the resolved columns to render are empty or invalid
        """
        self._collection_name = collection_name
        collection_dir = os.path.join(self._collections_dir, collection_name)
        results_dir = os.path.join(collection_dir, self._results_dir_name)
        self._result_paths = glob('{}/*.json'.format(results_dir))
        self._model_labels = [ModelReporter._get_label(x) for x in self._result_paths]
        if not self._model_labels:
            raise RuntimeError("Either wrong dataset label '{}' was given or the collection/dataset has no trained models. Dataset root: '{}', contents: [{}]. results contents: [{}]".format(
                collection_name, collection_dir,
                ModelReporter._describe_dir_contents(collection_dir),
                ModelReporter._describe_dir_contents(results_dir)))
        self._max_label_len = max(len(x) for x in self._model_labels)
        self._columns_to_render, self._columns_failed = [], []
        # Union of the columns found in any of the models' results.
        self.maximal_requested_columns = self.determine_maximal_set_of_renderable_columns_debug(self.exp_results)
        self._maximal_renderable_columns = self._get_maximal_renderable_columns()

        if not columns:
            self.columns_to_render = self._maximal_renderable_columns
        else:
            self.columns_to_render, self._columns_failed = self._get_renderable(self._maximal_renderable_columns, columns)
        if metric != 'alphabetical' and metric not in self.columns_to_render:
            raise InvalidMetricException("Metric '{}' should be either 'alphabetical' or within the recognized ones [{}]".format(metric, ', '.join(self.columns_to_render)))
        self._metric = metric
        print("Input metric to sort on: '{}'".format(self._metric))
        if verbose:
            print('Using: [{}]'.format(', '.join(self.columns_to_render)))
            # Iterate the ordered list (not a set) so the omitted columns are printed deterministically.
            print('Ommiting: [{}]'.format(', '.join(x for x in self._maximal_renderable_columns if x not in self.columns_to_render)))
            print('Failed: [{}]'.format(', '.join(self._columns_failed)))
        self._columns_titles = self.results_handler.get_titles(self.columns_to_render)
        # Column widths start at the title widths and grow while cells are rendered (see _to_string).
        self._max_col_lens = [len(x) for x in self._columns_titles]
        self.fitness_computer.highlightable_columns = [x for x in self.columns_to_render
                                                       if ModelReporter._get_hash_key(x) in self._column_keys_to_highlight]

    @staticmethod
    def _describe_dir_contents(dir_path):
        """Return a comma-separated listing of a directory's entries for error messages; never raises."""
        try:
            return ', '.join(sorted(os.listdir(dir_path)))
        except OSError as error:
            return '<unreadable: {}>'.format(error)

    @property
    def columns_to_render(self):
        """List of column definitions currently selected for rendering."""
        return self._columns_to_render

    @columns_to_render.setter
    def columns_to_render(self, column_definitions):
        """Validate and store the column definitions to render.

        The input is materialised into a new list, so one-shot iterables (e.g. generators)
        survive validation and the stored value supports indexing.

        :raises InvalidColumnsException: if the input is not iterable, is empty, or contains columns
            outside the maximal renderable set
        """
        if not isinstance(column_definitions, Iterable):
            raise InvalidColumnsException(
                "Input column definitions are of type '{}' instead of iterable".format(type(column_definitions)))
        column_definitions = list(column_definitions)
        if not column_definitions:
            raise InvalidColumnsException('Input column definitions evaluates to None')
        invalid_columns = ModelReporter._get_invalid_column_definitions(column_definitions,
                                                                        self._maximal_renderable_columns)
        if invalid_columns:
            raise InvalidColumnsException(
                'Input column definitions [{}] are not valid'.format(', '.join(invalid_columns)))
        self._columns_to_render = column_definitions

    ########## COLUMNS DEFINITIONS ##########
    def _get_maximal_renderable_columns(self):
        """Return all the inferred columns allowed to render, ordered by ResultsHandler.DEFAULT_COLUMNS.

        :raises RuntimeError: if some discovered column does not start with any supported default column
        """
        definitions = ModelReporter._get_column_definitions(self.results_handler.DEFAULT_COLUMNS,
                                                            self.maximal_requested_columns)
        if len(definitions) != len(self.maximal_requested_columns):
            raise RuntimeError("Discovered columns (from results): [{}]. Computed: [{}]. Missmatch with supported DYNAMIC [{}] and DEFAULT [{}] columns.".format(
                ', '.join(sorted(self.maximal_requested_columns)),
                ', '.join(sorted(definitions)),
                ', '.join(sorted(self.results_handler.DYNAMIC_COLUMNS)),
                ', '.join(sorted(self.results_handler.DEFAULT_COLUMNS))
            ))
        return definitions

    @staticmethod
    def _get_column_definitions(supported_columns, requested_column_definitions):
        """Return the requested column definitions that start with one of the supported columns.

        The result is grouped by supported column, in the order of ``supported_columns``, and sorted
        within each group. Empty inputs yield an empty list.
        """
        return [definition
                for prefix in supported_columns
                for definition in sorted(x for x in requested_column_definitions if x.startswith(prefix))]

    def _containing_column(self, exp_res_obj):
        """Return the columns reported by a single experimental results object."""
        return self.results_handler.get_all_columns(exp_res_obj, self.results_handler.DEFAULT_COLUMNS)

    def determine_maximal_set_of_renderable_columns(self, exp_results_list):
        """Return the union (a set) of the columns found in all the given results; empty input gives an empty set."""
        return set().union(*(self.results_handler.get_all_columns(x, self.results_handler.DEFAULT_COLUMNS)
                             for x in exp_results_list))

    def determine_maximal_set_of_renderable_columns_debug(self, exp_results_list):
        """Like determine_maximal_set_of_renderable_columns, but logs each model's columns at DEBUG level.

        :raises AssertionError: if a model reports the same column more than once
        """
        discovered = set()
        for exp_res in exp_results_list:
            columns = self.results_handler.get_all_columns(exp_res, self.results_handler.DEFAULT_COLUMNS)
            logger.debug("Model: %s, columns: [%s]", exp_res.scalars.model_label, ', '.join(sorted(columns)))
            unique_columns = set(columns)
            # Explicit check instead of `assert`, so it is not stripped when running with -O.
            if len(columns) != len(unique_columns):
                raise AssertionError("Model '{}' reports duplicate columns: [{}]".format(
                    exp_res.scalars.model_label, ', '.join(sorted(columns))))
            discovered |= unique_columns
        return discovered

    def _compute_rows(self):
        """Return the table body as a list of strings, one per model."""
        self._model_labels, values_lists = self._get_labels_n_values()
        # Render every cell first (this widens self._max_col_lens), only then pad the rows.
        rendered = [self._to_list_of_strings(values) for values in values_lists]
        return [self._to_row(label, strings) for label, strings in zip(self._model_labels, rendered)]

    def _get_labels_n_values(self):
        """Return a list of model labels and a list of value vectors (one per label).

        When sorting by a metric, the pairs are ordered by descending fitness, as computed by
        fitness_computer for that metric. The fitness_computer also determines the per-column
        values to highlight.
        """
        if not self.columns_to_render:
            raise RuntimeError("No valid columns to compute")
        vectors = self._results_value_vectors
        # NOTE(review): labels come from the glob order of the json files, while the vectors come from
        # results_handler.get_experimental_results. They are paired by position, which assumes both use the
        # same order. TODO confirm. The 'alphabetical' branch does not sort the labels itself.
        if self._metric != 'alphabetical':
            # Re-initialises the existing calculator in place for the selected metric/columns.
            self.fitness_computer.__init__(single_metric=self._metric, column_definitions=self.columns_to_render)
            try:
                return [list(t) for t in zip(*sorted(zip(self._model_labels, vectors),
                                                     key=lambda pair: self.fitness_computer(pair[1]),
                                                     reverse=True))]
            except IndexError:
                raise IndexError("Error: Probably no vectors (one per model holding the columns/metric values to report) computed: [{}]. ModelsL [{}]".format(', '.join(str(x) for x in vectors), ', '.join(x.scalars.model_label for x in self.exp_results)))
        return self._model_labels, [self.fitness_computer.pass_vector(x) for x in vectors]

    ########## STRING OPERATIONS ##########
    def _to_row(self, model_label, strings_list):
        """Return one table line: the padded model label, the separator and the padded cells."""
        cells = ' '.join('{}{}'.format(cell, ' ' * (self._max_col_lens[index] - self._length(cell)))
                         for index, cell in enumerate(strings_list))
        return '{}{}{} {}'.format(model_label,
                                  ' ' * (self._max_label_len - len(model_label)),
                                  self._label_separator,
                                  cells)

    def _to_list_of_strings(self, values_list):
        """Render each value of a vector with the column definition at the same position."""
        return [self._to_string(value, column) for value, column in zip(values_list, self.columns_to_render)]

    def _to_string(self, value, column_definition):
        """Render a single cell and widen its column if needed.

        None is rendered as '-'. A value equal to the fitness_computer's best value for its column
        is wrapped in the highlight decorators (which do not count towards the column width).
        """
        rendered = '-' if value is None else self.results_handler.stringnify(column_definition, value)
        column_index = self.columns_to_render.index(column_definition)
        self._max_col_lens[column_index] = max(self._max_col_lens[column_index], len(rendered))
        if column_definition in self.fitness_computer.best and value == self.fitness_computer.best[column_definition]:
            return self.highlight_pre_fix + rendered + self.highlight_post_fix
        return rendered

    def _length(self, a_string):
        """Return the visible width of a rendered cell, ignoring ANSI escape sequences (e.g. highlight decorators)."""
        return len(ModelReporter._ANSI_ESCAPE.sub('', a_string))

    ########## EXTRACTION ##########
    @property
    def _results_value_vectors(self):
        """One extracted value vector per model of the current collection."""
        return [self._extract_all(x) for x in self.exp_results]

    def _extract(self, exp_res, column):
        """Return the last value of a column for one model's experimental results."""
        return self.results_handler.extract(exp_res, column, 'last')

    def _extract_all(self, exp_results):
        """Return the vector of extracted values (one per column to render) for one model.

        Depending on the column, entries may be numbers, None (metric not tracked for this model),
        or a string, e.g. for a regularizers column.

        :raises RuntimeError: if the vector is empty
        """
        vector = [self._extract(exp_results, x) for x in self.columns_to_render]
        if not vector:
            raise RuntimeError("Empty vector!: {}. Failed to extract all [{}]".format(vector, ', '.join(sorted(str(x) for x in self.columns_to_render))))
        return vector

    def _get_renderable(self, allowed_renderable, columns):
        """
        Split the requested columns into renderable column definitions and invalid requests. The renderable
        columns are inferred from the requested 'columns' and the 'allowed' ones.

        :param list allowed_renderable:
        :param list columns:
        :return: 1st list: renderable inferred, 2nd list: invalid requested columns to render
        :rtype: list of two lists
        """
        renderable, failed = [], []
        for requested in columns:
            accepted, rejected = self._build_renderable(requested, allowed_renderable)
            renderable.extend(x for x in accepted if x)
            failed.extend(x for x in rejected if x)
        return [renderable, failed]

    def _build_renderable(self, requested_column, allowed_renderable):
        """Resolve one requested column into (renderable definitions, invalid requests); None marks 'nothing'."""
        if requested_column in allowed_renderable:
            return [requested_column], [None]
        elif requested_column in self.results_handler.DYNAMIC_COLUMNS:
            # a dynamic column expands to every allowed definition starting with it
            return sorted([x for x in allowed_renderable if x.startswith(requested_column)]), [None]
        elif requested_column in self.results_handler.DEFAULT_COLUMNS:  # columns that map to exactly one column to render, e.g. 'perplexity'
            return [requested_column], [None]
        else:  # the requested column is not one of the allowed renderable columns
            logger.warning("One of the discovered models requires to render the '{}' column but it is not within the infered allowed columns [{}], nor in the DYNAMIC [{}] or DEFAULT (see ResultsHandler) [{}].".format(
                requested_column,
                ', '.join(sorted(allowed_renderable)),
                ', '.join(sorted(self.results_handler.DYNAMIC_COLUMNS)),
                ', '.join(sorted(self.results_handler.DEFAULT_COLUMNS))
            ))
            return [None], [requested_column]

    ########## STATIC ##########

    @staticmethod
    def _get_invalid_column_definitions(column_defs, allowed_renderable):
        """Return the column definitions (in input order) that are not allowed to render."""
        allowed = set(allowed_renderable)
        return [x for x in column_defs if x not in allowed]

    @staticmethod
    def _get_label(json_path):
        """Return the model label: the json file name without its '.json' extension.

        Both '/' and '\\' separators are accepted, so Windows paths returned by glob work too.

        :raises AttributeError: if the path does not end with a separator followed by '<label>.json',
            where the label is made of word characters, '-', '.', '+' or '@'
        """
        match = re.search(r'[/\\]([\w\-\.\+@]+)\.json$', json_path)
        if match is None:
            print('PATH', json_path)
            raise AttributeError("'{}' does not look like a model results json path".format(json_path))
        return match.group(1)

    @staticmethod
    def _get_hash_key(column_definition):
        """Return the column definition keeping only its token parts, e.g. 'kernel-coherence-0.80' -> 'kernel-coherence'."""
        return '-'.join([x for x in column_definition.split('-') if ModelReporter._is_token(x)])

    @staticmethod
    def _parse_column_definition(definition):
        """Split a '-'-separated definition into [token parts, parameter parts]."""
        return [[part for part in group if part]
                for group in zip(*[(x, None) if ModelReporter._is_token(x) else (None, x) for x in definition.split('-')])]

    @staticmethod
    def _is_token(definition_element):
        """Return True for a name-like part of a column definition.

        Numbers, '@'-prefixed parts, single characters and empty strings are not tokens.
        """
        try:
            float(definition_element)
            return False
        except ValueError:
            if not definition_element or definition_element.startswith('@') or len(definition_element) == 1:
                return False
            return True
279
280
281
class InvalidColumnsException(Exception):
    """Raised when the column definitions to render are not iterable, are empty, or are not renderable."""


class InvalidMetricException(Exception):
    """Raised when the metric to sort on is neither 'alphabetical' nor one of the columns to render."""
283