import os
import re
import warnings
from glob import glob
from functools import reduce

import attr

from topic_modeling_toolkit.results import ExperimentalResults

from .fitness import FitnessFunction


KERNEL_SUB_ENTITIES = ('coherence', 'contrast', 'purity', 'size')

MAX_DECIMALS = 2  # this constant should agree with patm.modeling.experiment.Experiment.MAX_DECIMALS

def _get_kernel_sub_hash(entity):
    assert entity in KERNEL_SUB_ENTITIES
    return {'kernel-' + entity: {'scalar-extractor': lambda x, y: getattr(getattr(x.tracked, 'kernel' + y[2:]).average, entity).last if hasattr(x.tracked, 'kernel' + y[2:]) else None,
                                 'list-extractor': lambda x, y: getattr(getattr(x.tracked, 'kernel' + y[2:]).average, entity).all if hasattr(x.tracked, 'kernel' + y[2:]) else None,
                                 'column-title': lambda x: 'k' + entity[:3:2] + '.' + str(x)[2:],
                                 'to-string': '{:.4f}',
                                 'definitions': lambda x: ['kernel-{}-{}'.format(entity, y) for y in x.tracked.kernel_thresholds]}}
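# For illustration: _get_kernel_sub_hash('coherence') yields a 'kernel-coherence' entry whose
# 'column-title' callable turns a threshold string such as '0.80' into the abbreviation 'kch.80'
# (entity[:3:2] keeps the 1st and 3rd characters of the entity name).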

COLUMNS_HASH = {
    'nb-topics': {'scalar-extractor': lambda x: x.scalars.nb_topics,
                  'column-title': lambda: 'tpcs'},
    'collection-passes': {'scalar-extractor': lambda x: x.scalars.dataset_iterations,
                          'column-title': lambda: 'col-i'},
    'document-passes': {'scalar-extractor': lambda x: x.scalars.document_passes,
                        'column-title': lambda: 'doc-i'},
    'total-phi-updates': {'scalar-extractor': lambda x: x.scalars.dataset_iterations * x.scalars.document_passes,
                          'column-title': lambda: 'phi-u'},
    'perplexity': {'scalar-extractor': lambda x: x.tracked.perplexity.last,
                   'list-extractor': lambda x: x.tracked.perplexity.all,
                   'column-title': lambda: 'prpl',
                   'to-string': '{:.1f}'},
    'top-tokens-coherence': {'scalar-extractor': lambda x, y: getattr(x.tracked, 'top' + str(y)).average_coherence.last if hasattr(x.tracked, 'top' + str(y)) else None,
                             'list-extractor': lambda x, y: getattr(x.tracked, 'top' + str(y)).average_coherence.all if hasattr(x.tracked, 'top' + str(y)) else None,
                             'column-title': lambda x: 'top' + str(x) + 'ch',
                             'to-string': '{:.4f}',
                             'definitions': lambda x: ['top-tokens-coherence-' + str(y) for y in x.tracked.top_tokens_cardinalities]},
    'sparsity-phi': {'scalar-extractor': lambda x, y: getattr(x.tracked, 'sparsity_phi_' + y).last if hasattr(x.tracked, 'sparsity_phi_' + y) else None,
                     'list-extractor': lambda x, y: getattr(x.tracked, 'sparsity_phi_' + y).all if hasattr(x.tracked, 'sparsity_phi_' + y) else None,
                     'column-title': lambda y: 'spp@' + y,
                     'to-string': '{:.2f}',
                     'definitions': lambda x: ['sparsity-phi-{}'.format(y) for y in x.tracked.modalities_initials]},
    'sparsity-theta': {'scalar-extractor': lambda x: x.tracked.sparsity_theta.last,
                       'list-extractor': lambda x: x.tracked.sparsity_theta.all,
                       'column-title': lambda: 'spt',
                       'to-string': '{:.2f}'},
    'background-tokens-ratio': {'scalar-extractor': lambda x, y: getattr(x.tracked, 'background_tokens_ratio_' + str(y)[2:]).last if hasattr(x.tracked, 'background_tokens_ratio_' + str(y)[2:]) else None,
                                'list-extractor': lambda x, y: getattr(x.tracked, 'background_tokens_ratio_' + str(y)[2:]).all if hasattr(x.tracked, 'background_tokens_ratio_' + str(y)[2:]) else None,
                                'column-title': lambda x: 'btr.' + str(x)[2:],
                                'to-string': '{:.2f}',
                                'definitions': lambda x: ['background-tokens-ratio-{}'.format(y[:4]) for y in x.tracked.background_tokens_thresholds]},
    'regularizers': {'scalar-extractor': lambda x: '[{}]'.format(', '.join(map(regularizers_format, x.regularizers))),
                     'column-title': lambda: 'regs'},
    'kernel-size': {'scalar-extractor': lambda x, y: getattr(x.tracked, 'kernel' + y[2:]).average.size.last if hasattr(x.tracked, 'kernel' + y[2:]) else None,
                    'list-extractor': lambda x, y: getattr(x.tracked, 'kernel' + y[2:]).average.size.all if hasattr(x.tracked, 'kernel' + y[2:]) else None,
                    'column-title': lambda x: 'k' + 'size'[:3:2] + '.' + str(x)[2:],
                    'to-string': '{:.1f}',
                    'definitions': lambda x: ['kernel-size-{}'.format(y) for y in x.tracked.kernel_thresholds]}
}

def regularizers_format(reg_def_string):
    return reg_def_string
    # try:
    #     return '-'.join(re.findall(r"(?:^|-)(\w{1,4})\w*", reg_def_string[:reg_def_string.index("|")])) + reg_def_string[reg_def_string.index("|"):]
    # except ValueError as e:
    #     print(e)
    #     return reg_def_string

COLUMNS_HASH = reduce(lambda x, y: dict(y, **x), [COLUMNS_HASH] + [_get_kernel_sub_hash(z) for z in KERNEL_SUB_ENTITIES])
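# Note: dict(y, **x) lets keys already present in the accumulated hash win, so the explicit
# 'kernel-size' entry above takes precedence over the one generated by _get_kernel_sub_hash('size')
# (effectively only the 'to-string' format differs between the two).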


class ResultsHandler(object):
    _list_selector_hash = {str: lambda x: x[0] if x[1] == 'all' else None,
                           range: lambda x: [x[0][_] for _ in x[1]],
                           int: lambda x: x[0][:x[1]],
                           list: lambda x: [x[0][_] for _ in x[1]]}
    _QUANTITY_2_EXTRACTOR_KEY = {'last': 'scalar', 'all': 'list'}
    DYNAMIC_COLUMNS = ['kernel-size', 'kernel-coherence', 'kernel-contrast', 'kernel-purity', 'top-tokens-coherence', 'sparsity-phi', 'background-tokens-ratio']
    DEFAULT_COLUMNS = ['nb-topics', 'collection-passes', 'document-passes', 'total-phi-updates', 'perplexity'] + \
                      DYNAMIC_COLUMNS[:-1] + ['sparsity-theta'] + [DYNAMIC_COLUMNS[-1]] + ['regularizers']
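    # For reference, DEFAULT_COLUMNS evaluates to:
    # ['nb-topics', 'collection-passes', 'document-passes', 'total-phi-updates', 'perplexity',
    #  'kernel-size', 'kernel-coherence', 'kernel-contrast', 'kernel-purity', 'top-tokens-coherence',
    #  'sparsity-phi', 'sparsity-theta', 'background-tokens-ratio', 'regularizers']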

    def __init__(self, collections_root_path, results_dir_name='results'):
        self._collections_root = collections_root_path
        self._results_dir_name = results_dir_name
        self._results_hash = {}
        self._fitness_function_hash = {}
        self._list_selector = None

    def get_experimental_results(self, collection_name, sort='', selection='all'):
        """
        Call this method to get a list of experimental result objects from topic models trained on the given collection.\n
        :param str collection_name:
        :param str sort: metric to sort by; if empty, the experimental results are ordered alphabetically by their json path
        :param str or range or int or list selection: which subset of the experimental results of the given collection to select\n
            - if selection == 'all', returns every experimental results object "extracted" from the jsons
            - if type(selection) == range, returns a "slice" of the experimental results based on the range
            - if type(selection) == int, returns the first n experimental results
            - if type(selection) == list, it holds either the specific indices to sample the experimental results with, or the model labels to select by
        :return: the ExperimentalResults objects
        :rtype: list
        """
        result_paths = glob('{}/*.json'.format(os.path.join(self._collections_root, collection_name, self._results_dir_name)))
        if type(selection) == list and all(type(x) == str for x in selection):  # the input list contains model labels
            e = self._get_experimental_results([_ for _ in result_paths if re.search(r'/(?:{})\.json'.format('|'.join(selection)), _)])
            if len(e) != len(selection):
                raise ValueError("len1 = {}, len2 = {}\nseq1 = {}\nseq2 = {}".format(len(e), len(selection), [_.scalars.model_label for _ in e], selection))
            return e
        self._list_selector = lambda y: ResultsHandler._list_selector_hash[type(selection)]([y, selection])
        r = self._get_experimental_results(result_paths, metric_sorter=self._get_metric(sort))

        assert len(result_paths) == len(r)
        return self._list_selector(r)
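    # Illustrative calls (collection name and model label below are hypothetical):
    #   handler.get_experimental_results('my_collection')                         # all results, ordered by json path
    #   handler.get_experimental_results('my_collection', sort='perplexity')      # sorted by the 'perplexity' fitness function
    #   handler.get_experimental_results('my_collection', selection=5)            # the first 5 results
    #   handler.get_experimental_results('my_collection', selection=['model-a'])  # select by model label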

    def _get_experimental_results(self, results_paths, metric_sorter=None):
        # print('_get_experimental_results.metric_sorter: {}'.format(metric_sorter))
        if metric_sorter is None:
            return [self._process_result_path(_) for _ in sorted(results_paths)]
        return metric_sorter([self._process_result_path(x) for x in results_paths])
        # sorted([self._process_result_path(x) for x in results_paths], key=metric_sorter, reverse=True)
        # sorter = ResultsSorter.from_function(metric_sorter)
        # return sorter([self._process_result_path(x) for x in results_paths])

        # if metric_sorter:
        #     assert hasattr(metric_sorter, '__call__')
        #     print(' Metric function:', metric_sorter.__name__)

    def _process_result_path(self, result_path):
        if result_path not in self._results_hash:
            self._results_hash[result_path] = ExperimentalResults.create_from_json_file(result_path)
        return self._results_hash[result_path]

    def _get_metric(self, metric):
        if not metric:
            return None
        if metric not in self._fitness_function_hash:
            self._fitness_function_hash[metric] = FitnessFunction.single_metric(metric)
        return MetricSorter(metric, lambda x: self._fitness_function_hash[metric].compute([ResultsHandler.extract(x, metric, 'last')]))

    @staticmethod
    def get_titles(column_definitions):
        return [ResultsHandler.get_abbreviation(x) for x in column_definitions]

    @staticmethod
    def get_abbreviation(definition):
        tokens, parameters = ResultsHandler._parse_column_definition(definition)
        return COLUMNS_HASH['-'.join(tokens)]['column-title'](*parameters)

    @staticmethod
    def stringnify(column, value):
        """
        :param str column: key or definition; example values: 'perplexity', 'kernel-coherence', 'kernel-coherence-0.80'
        :param value:
        :return: the value rendered with the column's 'to-string' format
        :rtype: str
        """
        return COLUMNS_HASH.get(column, COLUMNS_HASH[ResultsHandler._get_hash_key(column)]).get('to-string', '{}').format(value)
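    # e.g. ResultsHandler.stringnify('perplexity', 1234.5678) -> '1234.6' (format '{:.1f}');
    # a column without a 'to-string' entry falls back to plain '{}' formatting.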

    # @staticmethod
    # def get_tau_trajectory(exp_results, matrix_name):
    #     """
    #
    #     :param results.experimental_results.ExperimentalResults exp_results:
    #     :param matrix_name:
    #     :return:
    #     """
    #     return getattr(exp_results.tracked.tau_trajectories, matrix_name).all

    @staticmethod
    def extract(exp_results, column_definition, quantity):
        """
        Call this method to query the given experimental results object about a specific metric. It supports requesting
        either the last value or all the values tracked along the training process.\n
        :param results.experimental_results.ExperimentalResults exp_results:
        :param str column_definition:
        :param str quantity: must be one of {'last', 'all'}
        :return:
        """
        tokens, parameters = ResultsHandler._parse_column_definition(column_definition)
        return COLUMNS_HASH['-'.join(tokens)][ResultsHandler._QUANTITY_2_EXTRACTOR_KEY[quantity] + '-extractor'](*([exp_results] + parameters))
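    # e.g. ResultsHandler.extract(exp_results, 'kernel-coherence-0.80', 'all') dispatches to
    # COLUMNS_HASH['kernel-coherence']['list-extractor'](exp_results, '0.80') and returns the
    # tracked per-iteration values (or None if that kernel threshold was not tracked).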

    @staticmethod
    def get_all_columns(exp_results, requested_entities):
        return reduce(lambda i, j: i + j,
                      [COLUMNS_HASH[x]['definitions'](exp_results) if x in ResultsHandler.DYNAMIC_COLUMNS else [x] for x in requested_entities])

    ###### UTILITY FUNCTIONS ######
    @staticmethod
    def _parse_column_definition(definition):
        return [list([_f for _f in y if _f]) for y in zip(*[(x, None) if ResultsHandler._is_token(x) else (None, x) for x in definition.split('-')])]
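    # e.g. _parse_column_definition('kernel-coherence-0.80') -> [['kernel', 'coherence'], ['0.80']]:
    # alphabetic tokens name the COLUMNS_HASH entry, while numeric, '@'-prefixed or single-character
    # elements are treated as parameters passed to the entry's callables.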

    @staticmethod
    def _get_hash_key(column_definition):
        return '-'.join([_ for _ in column_definition.split('-') if ResultsHandler._is_token(_)])

    @staticmethod
    def _is_token(definition_element):
        try:
            _ = float(definition_element)
            return False
        except ValueError:
            if definition_element[0] == '@' or len(definition_element) == 1:
                return False
            return True

    @staticmethod
    def _label_selection(labels, experimental_results_list):
        """Returns the indices of the input labels based on the input experimental results list"""
        model_labels = [x.scalars.model_label for x in experimental_results_list]
        return [model_labels.index(l) for l in labels if l in model_labels]


##########################
def _build_sorter(instance, attribute, value):
    if not hasattr(value, '__call__'):
        raise TypeError("Second constructor argument should be a callable object, in case the first is 'alphabetical'")
    # this sorts from 'bigger' to 'smaller' (independently of how the <, > operators have been defined)
    instance.experimental_result_sorter = lambda exp_res_objs_list: sorted(exp_res_objs_list, key=value, reverse=True)

@attr.s(cmp=True, repr=True, str=True)
class MetricSorter(object):
    name = attr.ib(init=True, converter=str, cmp=True, repr=True)
    experimental_result_sorter = attr.ib(init=True, validator=_build_sorter, cmp=True, repr=True)

    @classmethod
    def from_function(cls, function):
        return cls(function.__name__, function)

    def __call__(self, *args, **kwargs):
        return self.experimental_result_sorter(args[0])
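    # e.g. MetricSorter('perplexity', lambda res: res.tracked.perplexity.last)(results_list) returns
    # results_list sorted from the highest to the lowest value of the supplied key callable.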
############################


if __name__ == '__main__':
    # ModelSelector does not exist in this module; minimal smoke test with ResultsHandler instead,
    # splitting the original path into collections root ('/data/thesis/data/collections') and collection ('dd')
    handler = ResultsHandler('/data/thesis/data/collections')
    print([res.scalars.model_label for res in handler.get_experimental_results('dd')])