1
|
|
|
import re
import os
import sys
import logging
from glob import glob
from functools import reduce

try:
    from collections.abc import Iterable
except ImportError:  # fallback for interpreters without collections.abc; the alias was removed from collections in Python 3.10
    from collections import Iterable

from .fitness import FitnessCalculator
from .model_selection import ResultsHandler
13
|
|
|
# Module-level logger, pinned to the INFO level.
# NOTE(review): setting the level here overrides whatever level the application configures
# for this logger's ancestors; consider leaving that decision to the application.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
15
|
|
|
|
16
|
|
|
|
17
|
|
|
class ModelReporter:
    """Render a plain-text table that summarizes the trained models of a collection.

    One row is produced per results ``*.json`` file found under
    ``<collections_root_path>/<collection_name>/<results_dir_name>`` and one column per column
    definition resolved through :class:`ResultsHandler`. Rows are ordered either by model label
    ('alphabetical') or by the fitness that :class:`FitnessCalculator` computes for a chosen metric.
    A cell whose value equals ``fitness_computer.best[column]`` is wrapped in the
    ``highlight_pre_fix`` / ``highlight_post_fix`` ANSI codes.
    """

    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    ENDC = '\033[0m'

    # Matches ANSI "Select Graphic Rendition" escape sequences (such as UNDERLINE and ENDC above);
    # used to measure the visible width of a decorated cell.
    _ANSI_ESCAPE_RE = re.compile(r'\x1b\[[0-9;]*m')

    def __init__(self, collections_root_path, results_dir_name='results'):
        """
        :param str collections_root_path: directory holding one sub-directory per collection/dataset
        :param str results_dir_name: name of the per-collection sub-directory holding the results json files
        """
        self._collections_dir = collections_root_path
        self.results_handler = ResultsHandler(self._collections_dir, results_dir_name=results_dir_name)
        self._results_dir_name = results_dir_name
        self._label_separator = ':'
        self._columns_to_render = []
        # Columns whose "hash key" (see _get_hash_key) is listed here are eligible for highlighting.
        self._column_keys_to_highlight = ['perplexity', 'kernel-coherence', 'kernel-purity', 'kernel-contrast',
                                          'top-tokens-coherence', 'sparsity-phi', 'sparsity-theta',
                                          'background-tokens-ratio']
        self.highlight_pre_fix = ModelReporter.UNDERLINE
        self.highlight_post_fix = ModelReporter.ENDC
        self.fitness_computer = FitnessCalculator()
        self._max_label_len = 0
        self._max_col_lens = []
        # State (re)computed by _initialize(); declared here so the object is always consistent
        # (eg assigning 'columns_to_render' before a report raises InvalidColumnsException, not AttributeError).
        self._maximal_renderable_columns = []
        self._columns_failed = []
        self._columns_titles = []

    def get_formatted_string(self, collection_name, columns=None, metric='', verbose=True):
        """Build the report table of the given collection.

        :param str collection_name: name of the collection (sub-directory of the collections root)
        :param list columns: column definitions to render; a falsy value renders all discovered columns
        :param str metric: 'alphabetical' to order the rows by model label, or one of the rendered column
            definitions to order them by the fitness FitnessCalculator computes (highest first).
            Note that the default '' is rejected with InvalidMetricException.
        :param bool verbose: whether to print the metric and the used/omitted/failed columns
        :return: a header line with the column titles followed by one line per model
        :rtype: str
        """
        self._initialize(collection_name, columns=columns, metric=metric, verbose=verbose)
        # The rows must be rendered first: rendering widens self._max_col_lens to fit the values.
        body = '\n'.join(self._compute_rows())
        padded_titles = ['{}{}'.format(title, ' ' * (self._max_col_lens[index] - len(title)))
                         for index, title in enumerate(self._columns_titles[:-1])]
        # The last title is not padded. Joining it together with the others (instead of appending it
        # with an extra separator) keeps the header aligned with the cells when there is a single column.
        head = '{}{} {}'.format(' ' * self._max_label_len,
                                ' ' * len(self._label_separator),
                                ' '.join(padded_titles + self._columns_titles[-1:]))
        return head + '\n' + body

    @property
    def exp_results(self):
        """Experimental results of every model of the current collection.

        Each access queries ResultsHandler.get_experimental_results again.

        :rtype: list of results.experimental_results.ExperimentalResults
        """
        return self.results_handler.get_experimental_results(self._collection_name, selection='all')

    def _initialize(self, collection_name, columns=None, metric='', verbose=False):
        """Discover the models of a collection and resolve the columns to render and the metric to sort on.

        :raises RuntimeError: when no results json file exists for the collection
        :raises InvalidColumnsException: when the resolved columns cannot be rendered
        :raises InvalidMetricException: when 'metric' is neither 'alphabetical' nor one of the rendered columns
        """
        self._collection_name = collection_name
        results_dir = os.path.join(self._collections_dir, collection_name, self._results_dir_name)
        self._result_paths = glob('{}/*.json'.format(results_dir))
        self._model_labels = [ModelReporter._get_label(x) for x in self._result_paths]
        if not self._model_labels:
            raise self._no_models_error(collection_name)
        self._max_label_len = max(len(label) for label in self._model_labels)
        self._columns_to_render, self._columns_failed = [], []
        # NOTE: the reported columns may differ between models (eg different trackables per model),
        # hence the union over all models is used.
        self.maximal_requested_columns = self.determine_maximal_set_of_renderable_columns_debug(self.exp_results)
        self._maximal_renderable_columns = self._get_maximal_renderable_columns()

        if not columns:
            self.columns_to_render = self._maximal_renderable_columns
        else:
            self.columns_to_render, self._columns_failed = self._get_renderable(self._maximal_renderable_columns, columns)
        if metric != 'alphabetical' and metric not in self.columns_to_render:
            raise InvalidMetricException("Metric '{}' should be either 'alphabetical' or within the recognized ones [{}]".format(metric, ', '.join(self.columns_to_render)))
        self._metric = metric
        if verbose:
            print("Input metric to sort on: '{}'".format(self._metric))
            print('Using: [{}]'.format(', '.join(self.columns_to_render)))
            # A list (not a set) keeps the output deterministic and in column order.
            print('Ommiting: [{}]'.format(', '.join([c for c in self._maximal_renderable_columns if c not in self.columns_to_render])))
            print('Failed: [{}]'.format(', '.join(self._columns_failed)))
        self._columns_titles = self.results_handler.get_titles(self.columns_to_render)
        self._max_col_lens = [len(title) for title in self._columns_titles]
        self.fitness_computer.highlightable_columns = [c for c in self.columns_to_render
                                                       if ModelReporter._get_hash_key(c) in self._column_keys_to_highlight]

    def _no_models_error(self, collection_name):
        """Build the error raised when no results json file is found for 'collection_name'.

        Lists the collection and results directory contents to help diagnose a wrong name or path.
        """
        collection_dir = os.path.join(self._collections_dir, collection_name)
        return RuntimeError(
            "Either wrong dataset label '{}' was given or the collection/dataset has no trained models. "
            "Dataset root: '{}', contents: [{}]. results contents: [{}]".format(
                collection_name, collection_dir,
                ModelReporter._list_dir(collection_dir),
                ModelReporter._list_dir(os.path.join(collection_dir, self._results_dir_name))))

    @staticmethod
    def _list_dir(path):
        """Return the sorted, comma-separated entries of 'path', or a placeholder if it cannot be listed."""
        try:
            return ', '.join(sorted(os.listdir(path)))
        except OSError as error:  # eg the directory does not exist; do not mask the caller's error
            return '<cannot list: {}>'.format(error)

    @property
    def columns_to_render(self):
        """Column definitions rendered in the report, in display order."""
        return self._columns_to_render

    @columns_to_render.setter
    def columns_to_render(self, column_definitions):
        """Validate and store the column definitions to render.

        :raises InvalidColumnsException: if the input is not iterable, is empty, or contains definitions
            that are not among the maximal renderable columns
        """
        if not isinstance(column_definitions, Iterable):
            raise InvalidColumnsException(
                "Input column definitions are of type '{}' instead of iterable".format(type(column_definitions)))
        if not column_definitions:
            raise InvalidColumnsException('Input column definitions evaluates to None')
        invalid_columns = ModelReporter._get_invalid_column_definitions(column_definitions,
                                                                        self._maximal_renderable_columns)
        if invalid_columns:
            raise InvalidColumnsException(
                'Input column definitions [{}] are not valid'.format(', '.join(invalid_columns)))
        self._columns_to_render = column_definitions

    ########## COLUMNS DEFINITIONS ##########
    def _get_maximal_renderable_columns(self):
        """Return all the inferred columns allowed to render, ordered by ResultsHandler.DEFAULT_COLUMNS.

        :raises RuntimeError: if some discovered column does not match exactly one supported column prefix
        """
        renderable = ModelReporter._get_column_definitions(self.results_handler.DEFAULT_COLUMNS,
                                                           self.maximal_requested_columns)
        if len(renderable) != len(self.maximal_requested_columns):
            raise RuntimeError("Discovered columns (from results): [{}]. Computed: [{}]. Missmatch with supported DYNAMIC [{}] and DEFAULT [{}] columns.".format(
                ', '.join(sorted(self.maximal_requested_columns)),
                ', '.join(sorted(renderable)),
                ', '.join(sorted(self.results_handler.DYNAMIC_COLUMNS)),
                ', '.join(sorted(self.results_handler.DEFAULT_COLUMNS))
            ))
        return renderable

    @staticmethod
    def _get_column_definitions(supported_columns, requested_column_definitions):
        """Return the requested column definitions that start with one of the supported column prefixes.

        The result is grouped by the order of 'supported_columns' (alphabetically within each group).
        An empty 'supported_columns' yields an empty list.
        """
        return [column
                for prefix in supported_columns
                for column in sorted(c for c in requested_column_definitions if c.startswith(prefix))]

    def _containing_column(self, exp_res_obj):
        """Columns reported by a single experimental result."""
        return self.results_handler.get_all_columns(exp_res_obj, self.results_handler.DEFAULT_COLUMNS)

    def determine_maximal_set_of_renderable_columns(self, exp_results_list):
        """Union of the columns reported by every experimental result (an empty input yields an empty set)."""
        return set().union(*[set(self.results_handler.get_all_columns(x, self.results_handler.DEFAULT_COLUMNS))
                             for x in exp_results_list])

    def determine_maximal_set_of_renderable_columns_debug(self, exp_results_list):
        """Same as determine_maximal_set_of_renderable_columns, but logs the columns of each model.

        :raises AssertionError: if a model reports the same column more than once
        """
        discovered = set()
        for exp_res in exp_results_list:
            columns = self.results_handler.get_all_columns(exp_res, self.results_handler.DEFAULT_COLUMNS)
            logger.debug("Model: {}, columns: [{}]".format(exp_res.scalars.model_label, ', '.join(sorted(columns))))
            unique_columns = set(columns)
            # Explicit raise (instead of 'assert') so the check also runs under 'python -O'.
            if len(columns) != len(unique_columns):
                raise AssertionError("Model '{}' reports duplicate columns: [{}]".format(
                    exp_res.scalars.model_label, ', '.join(sorted(columns))))
            discovered |= unique_columns
        return discovered

    def _compute_rows(self):
        """Return one formatted row per model, in the order dictated by the metric."""
        self._model_labels, values_lists = self._get_labels_n_values()
        # All cells must be stringified before any row is built: stringifying updates
        # self._max_col_lens, which _to_row uses for padding.
        strings_lists = [self._to_list_of_strings(values) for values in values_lists]
        return [self._to_row(label, strings) for label, strings in zip(self._model_labels, strings_lists)]

    def _get_labels_n_values(self):
        """Return the model labels and the list of reportable value vectors (one per label), ordered.

        With the 'alphabetical' metric the rows are sorted by model label and each vector goes through
        fitness_computer.pass_vector; otherwise the rows are sorted by fitness_computer (highest first).
        The fitness_computer also determines the values to be highlighted.

        :rtype: tuple of two lists
        """
        if not self.columns_to_render:
            raise RuntimeError("No valid columns to compute")
        # NOTE(review): labels come from the globbed json paths while the vectors come from
        # ResultsHandler.get_experimental_results; the pairing assumes both list models in the same order.
        vectors = self._results_value_vectors  # extracted once: each access reloads the results
        if self._metric != 'alphabetical':
            self.fitness_computer.__init__(single_metric=self._metric, column_definitions=self.columns_to_render)
            try:
                ranked = sorted(zip(self._model_labels, vectors),
                                key=lambda pair: self.fitness_computer(pair[1]), reverse=True)
            except IndexError:
                raise IndexError("Error: Probably no vectors (one per model holding the columns/metric values to report) computed: [{}]. ModelsL [{}]".format(', '.join(str(_) for _ in vectors), ', '.join(x.scalars.model_label for x in self.exp_results)))
            return [label for label, _ in ranked], [vector for _, vector in ranked]
        # pass_vector is called in the original model order; only the resulting pairs are sorted.
        values = [self.fitness_computer.pass_vector(vector) for vector in vectors]
        ordered = sorted(zip(self._model_labels, values), key=lambda pair: pair[0])
        return [label for label, _ in ordered], [value for _, value in ordered]

    ########## STRING OPERATIONS ##########
    def _to_row(self, model_label, strings_list):
        """Format one table row: padded label, separator, then the padded cells."""
        cells = ['{}{}'.format(text, ' ' * (self._max_col_lens[index] - self._length(text)))
                 for index, text in enumerate(strings_list)]
        return '{}{}{} {}'.format(model_label,
                                  ' ' * (self._max_label_len - len(model_label)),
                                  self._label_separator,
                                  ' '.join(cells))

    def _to_list_of_strings(self, values_list):
        """Stringify a vector of values, one per rendered column."""
        return [self._to_string(value, column, column_index=index)
                for index, (value, column) in enumerate(zip(values_list, self.columns_to_render))]

    def _to_string(self, value, column_definition, column_index=None):
        """Stringify one cell, widening the column if needed and highlighting the column's best value.

        :param value: extracted value; None is rendered as '-'
        :param str column_definition: the column the value belongs to
        :param int column_index: position of the column; looked up in columns_to_render when omitted
        """
        if column_index is None:
            column_index = self.columns_to_render.index(column_definition)
        text = '-' if value is None else self.results_handler.stringnify(column_definition, value)
        self._max_col_lens[column_index] = max(self._max_col_lens[column_index], len(text))
        if column_definition in self.fitness_computer.best and value == self.fitness_computer.best[column_definition]:
            return self.highlight_pre_fix + text + self.highlight_post_fix
        return text

    def _length(self, a_string):
        """Visible length of a string, ignoring ANSI rendering decorations (eg highlight codes)."""
        return len(ModelReporter._ANSI_ESCAPE_RE.sub('', a_string))

    ########## EXTRACTION ##########
    @property
    def _results_value_vectors(self):
        """One vector of extracted values per experimental result (reloads the results on each access)."""
        return [self._extract_all(x) for x in self.exp_results]

    def _extract(self, exp_res, column):
        """Value of 'column' for one experimental result.

        The 'last' argument is forwarded to ResultsHandler.extract (presumably selects the last
        tracked value — confirm against ResultsHandler).
        """
        return self.results_handler.extract(exp_res, column, 'last')

    def _extract_all(self, exp_results):
        """Get a list (vector) of extracted values, one per rendered column.

        It shall contain integers, floats, Nones (for metrics not tracked for the specific model), a single
        string representing the regularization specifications and nan for the 'sparsity-phi-i' metric.

        :raises RuntimeError: if no value could be extracted (ie there are no columns to render)
        """
        vector = [self._extract(exp_results, column) for column in self.columns_to_render]
        if not vector:
            raise RuntimeError("Empty vector!: {}. Failed to extract all [{}]".format(vector, ', '.join(sorted(str(x) for x in self.columns_to_render))))
        return vector

    def _get_renderable(self, allowed_renderable, columns):
        """
        Call this method to get the list of valid renderable columns and the list of invalid ones. The renderable columns
        are inferred from the selected 'columns' and the 'allowed' ones.\n
        :param list allowed_renderable:
        :param list columns:
        :return: 1st list: renderable inferred, 2nd list: invalid requested columns to render
        :rtype: list of two lists
        """
        renderable, invalid = [], []
        for requested in columns:
            found, failed = self._build_renderable(requested, allowed_renderable)
            # _build_renderable uses None as a placeholder for "nothing to report"; drop it.
            renderable.extend(column for column in found if column)
            invalid.extend(column for column in failed if column)
        return [renderable, invalid]

    def _build_renderable(self, requested_column, allowed_renderable):
        """Resolve one requested column into (renderable columns, invalid columns); None entries are placeholders."""
        if requested_column in allowed_renderable:
            return [requested_column], [None]
        elif requested_column in self.results_handler.DYNAMIC_COLUMNS:
            # a dynamic column expands to all the allowed definitions sharing its prefix
            return sorted([c for c in allowed_renderable if c.startswith(requested_column)]), [None]
        elif requested_column in self.results_handler.DEFAULT_COLUMNS:  # if c is one of the columns that map to exactly one column to render; ie 'perplexity'
            return [requested_column], [None]
        else:  # requested column is invalid: is not one of the allowed renderable columns
            logger.warning("One of the discovered models requires to render the '{}' column but it is not within the infered allowed columns [{}], nor in the DYNAMIC [{}] or DEFAULT (see ResultsHandler) [{}].".format(
                requested_column,
                ', '.join(sorted(allowed_renderable)),
                ', '.join(sorted(self.results_handler.DYNAMIC_COLUMNS)),
                ', '.join(sorted(self.results_handler.DEFAULT_COLUMNS))
            ))
            return [None], [requested_column]

    ########## STATIC ##########

    @staticmethod
    def _get_invalid_column_definitions(column_defs, allowed_renderable):
        """Column definitions not present in 'allowed_renderable', in input order."""
        return [c for c in column_defs if c not in allowed_renderable]

    @staticmethod
    def _get_label(json_path):
        """Model label of a results file: its file name without the '.json' extension.

        Works with any OS path separator.

        :raises AttributeError: if the file name is not '<label>.json' with a label made of word
            characters, '-', '.', '+' or '@'
        """
        match = re.match(r'([\w\-.+@]+)\.json$', os.path.basename(json_path))
        if match is None:
            raise AttributeError("Cannot infer a model label from results file path '{}'".format(json_path))
        return match.group(1)

    @staticmethod
    def _get_hash_key(column_definition):
        """Join the token elements of a column definition, eg 'topic-kernel-0.60-0.80' -> 'topic-kernel'."""
        return '-'.join([e for e in column_definition.split('-') if ModelReporter._is_token(e)])

    @staticmethod
    def _parse_column_definition(definition):
        """Split a column definition into [token elements, non-token (parameter) elements].

        Empty elements are dropped, eg 'topic-kernel-0.60-0.80' -> [['topic', 'kernel'], ['0.60', '0.80']].
        """
        elements = definition.split('-')
        return [[e for e in elements if ModelReporter._is_token(e)],
                [e for e in elements if e and not ModelReporter._is_token(e)]]

    @staticmethod
    def _is_token(definition_element):
        """True for a 'word' element of a column definition.

        Returns False for numbers, for elements starting with '@', for single characters and for the empty string.
        """
        try:
            float(definition_element)
            return False
        except ValueError:
            if not definition_element or definition_element[0] == '@' or len(definition_element) == 1:
                return False
            return True
279
|
|
|
|
280
|
|
|
|
281
|
|
|
class InvalidColumnsException(Exception):
    """Raised when the requested column definitions are not iterable, are empty or are not renderable."""


class InvalidMetricException(Exception):
    """Raised when the metric to sort on is neither 'alphabetical' nor one of the rendered columns."""
283
|
|
|
|