|
1
|
|
|
import warnings |
|
2
|
|
|
from collections import OrderedDict, Counter |
|
3
|
|
|
import attr |
|
4
|
|
|
from functools import reduce |
|
5
|
|
|
import pprint |
|
6
|
|
|
from types import FunctionType |
|
7
|
|
|
from tqdm import tqdm |
|
8
|
|
|
from .parameters import ParameterGrid |
|
9
|
|
|
from ..modeling import TrainerFactory, Experiment |
|
10
|
|
|
from ..definitions import DEFAULT_CLASS_NAME, IDEOLOGY_CLASS_NAME # this is the name of the default modality. it is irrelevant to class lebels or document lcassification |
|
11
|
|
|
|
|
12
|
|
|
|
|
13
|
|
|
@attr.s
class Versioning(object):
    """Make repeated labels unique by appending a zero-padded version suffix.

    The first time a label is passed to the instance it is returned unchanged.
    Every later occurrence is returned as ``<label>_v<NN>``, where ``NN`` is the
    number of previous occurrences, zero-padded to ``max_digits_version`` digits.
    """
    # Occurrence count per label; any iterable of labels given at init is turned into a Counter.
    objects = attr.ib(init=True, default=attr.Factory(list), converter=Counter, repr=True, cmp=True)
    # Number of digits the version number is zero-padded to.
    _max_digits_version = attr.ib(init=True, default=2, converter=int, repr=True, cmp=True)
    # Separator placed between the label and its version number.
    _joiner = '_v'

    def __call__(self, data):
        """Register one more occurrence of `data` and return its unique label.

        :param data: hashable label to version
        :return: `data` itself on first occurrence, else the suffixed label
        :raises RuntimeError: if the version number needs more digits than allowed
        """
        self.objects[data] += 1
        occurrences = self.objects[data]
        if occurrences > 1:
            return '{}{}{}'.format(data, self._joiner, self._iter_prepend(occurrences - 1))
        return data

    def _iter_prepend(self, int_num):
        """Return `int_num` as a string zero-padded to the configured width.

        :raises RuntimeError: if `int_num` has more digits than the configured width
        """
        digits = str(int_num)
        if len(digits) > self._max_digits_version:
            # Report the actual configured limit instead of assuming the default of 2 digits.
            raise RuntimeError("More than {} items are versioned. (but max_digit_length={})".format(
                10 ** self._max_digits_version, self._max_digits_version))
        return digits.zfill(self._max_digits_version)
|
32
|
|
|
|
|
33
|
|
|
|
|
34
|
|
|
@attr.s
class Tuner(object):
    """Grid-search tuner: builds, trains and saves one topic model per parameter combination.

    Training parameters and regularization specifications are given either as
    the first two positional arguments of :meth:`tune` or through the
    ``training_parameters`` / ``regularization_specs`` setters. Each parameter
    may be a single value or a list of candidate values to search over.
    """
    dataset = attr.ib(init=True, repr=True)
    # Score name -> short alias; passed as-is to the model factory.
    # Built by a factory so that instances do not share (and mutate) one default dict.
    scores = attr.ib(init=True, default=attr.Factory(lambda: {
        'perplexity': 'per',
        'sparsity-phi-@dc': 'sppd',
        'sparsity-phi-@ic': 'sppi',
        'sparsity-theta': 'spt',
        # 'topic-kernel-0.25': 'tk25',
        'topic-kernel-0.60': 'tk60',
        'topic-kernel-0.80': 'tk80',
        'top-tokens-10': 'top10',
        'top-tokens-100': 'top100',
        'background-tokens-ratio-0.3': 'btr3',
        'background-tokens-ratio-0.2': 'btr2',
    }))
    _training_parameters = attr.ib(init=True, default=attr.Factory(dict), converter=dict, repr=True)
    _reg_specs = attr.ib(init=True, default=attr.Factory(dict), converter=dict, repr=True)
    grid_searcher = attr.ib(init=True, default=None, repr=True)
    version = attr.ib(init=True, factory=Versioning, repr=True)

    _labeler = attr.ib(init=False, default=None)
    trainer = attr.ib(init=False, default=attr.Factory(
        lambda self: TrainerFactory().create_trainer(self.dataset, exploit_ideology_labels=True, force_new_batches=False),
        takes_self=True))
    experiment = attr.ib(init=False, default=attr.Factory(lambda self: Experiment(self.dataset), takes_self=True))

    def __attrs_post_init__(self):
        # Register the experiment with the trainer once both have been created.
        self.trainer.register(self.experiment)

    def __getitem__(self, item):
        """Return the parameter set named 'training' or 'regularization'.

        :raises KeyError: for any other name
        """
        if item == 'training':
            return self._training_parameters
        if item == 'regularization':
            return self._reg_specs
        raise KeyError(item)

    @property
    def parameter_names(self):
        """Training parameter names followed by regularization parameter names."""
        return self._training_parameters.parameter_names + self._reg_specs.parameter_names

    @property
    def constants(self):
        """Names of the parameters that have a single (steady) value."""
        return self._training_parameters.steady + self._reg_specs.steady

    @property
    def explorables(self):
        """Names of the parameters that have more than one candidate value."""
        return self._training_parameters.explorable + self._reg_specs.explorable

    @property
    def training_parameters(self):
        """The mixture of steady parameters and parameters to tune on"""
        return self._training_parameters

    @training_parameters.setter
    def training_parameters(self, training_parameters):
        """Provide the mixture of steady parameters and parameters to tune on"""
        self._training_parameters = ParametersMixture(training_parameters)

    @property
    def regularization_specs(self):
        """The specifications according to which regularization components should be activated, initialized and potentially evolved (see tau trajectory) during training"""
        return self._reg_specs

    @regularization_specs.setter
    def regularization_specs(self, regularization_specs):
        self._reg_specs = RegularizationSpecifications(regularization_specs)

    @property
    def current_reg_specs(self):
        """Regularizer settings of the current parameter vector.

        :return: dict mapping each regularizer type to a dict of parameter name -> current value
        """
        return {reg_type: {param_name: self._val('{}.{}'.format(reg_type, param_name))
                           for param_name in params_mixture.parameter_names}
                for reg_type, params_mixture in self._reg_specs}

    def _set_verbosity_level(self, input_verbose):
        """Store the verbosity level in ``self._vb``, clamped to [0, 5].

        Input that cannot be converted to an int falls back to level 3.
        """
        try:
            level = int(input_verbose)
        except (TypeError, ValueError):
            # TypeError covers None and other non-numeric objects, not only bad strings.
            level = 3
        self._vb = min(max(level, 0), 5)

    def tune(self, *args, **kwargs):
        """Train and save one model for every combination in the parameter grid.

        :param args: optionally the training parameters, then the regularization specs
        :param kwargs: optional settings read here: 'verbose' (0-5, default 3),
            'prefix_label', 'labeling_params', 'append_static',
            'append_explorables' (also accepted as 'append_explorable'),
            'preserve_order', 'parameter_set' and 'cache_theta'
        """
        self._set_verbosity_level(kwargs.get('verbose', 3))

        if len(args) > 0:
            self.training_parameters = args[0]
        if len(args) > 1:
            self.regularization_specs = args[1]

        # 'append_explorables' is the historical key; the singular spelling used by
        # LabelingDefinition is accepted as a fallback.
        append_explorable = kwargs.get('append_explorables', kwargs.get('append_explorable', True))
        self._labeler = LabelingDefinition.from_tuner(
            self,
            prefix=kwargs.get('prefix_label', ''),
            labeling_params=kwargs.get('labeling_params', False),
            append_static=kwargs.get('append_static', False),
            append_explorable=append_explorable,
            preserve_order=kwargs.get('preserve_order', True),
            parameter_set=kwargs.get('parameter_set', 'training'))

        # Training spans come first, then the spans of every regularizer; extract() relies on this order.
        self.grid_searcher = ParameterGrid(
            self._training_parameters.parameter_spans
            + [span for _, reg_params_mixture in self._reg_specs for span in reg_params_mixture.parameter_spans])

        if 1 < self._vb:
            print('Taking {} samples for grid-search'.format(len(self.grid_searcher)))
            # if kwargs.get('force_overwrite', True):
            print('Overwritting any existing results and phi matrices found')
        if self._vb:
            print('Tuning..')
            generator = tqdm(self.grid_searcher, total=len(self.grid_searcher), unit='model')
        else:
            generator = iter(self.grid_searcher)

        for self.parameter_vector in generator:
            self._cur_label = self.version(self._labeler(self.parameter_vector))
            # Warnings raised while constructing the model are recorded instead of printed.
            with warnings.catch_warnings(record=True):
                tm, specs = self._model()
            tqdm.write(self._cur_label)
            tqdm.write("Background: [{}]".format(', '.join(tm.background_topics)))
            tqdm.write("Domain: [{}]".format(', '.join(tm.domain_topics)))

            if 4 < self._vb:
                # Regularizer settings extended with the topics and modalities each one targets.
                report = {}
                for reg_type, settings in self.current_reg_specs.items():
                    reg_obj = tm.get_reg_obj(tm.get_reg_name(reg_type))
                    entry = dict(settings)
                    entry['target topics'] = self._topics_str(reg_obj.topic_names, tm.domain_topics, tm.background_topics)
                    entry['mods'] = getattr(reg_obj, 'class_ids', None)
                    report[reg_type] = entry
                tqdm.write(pprint.pformat(report))
            if 3 < self._vb:
                tqdm.write(pprint.pformat(tm.modalities_dictionary))
            self.experiment.init_empty_trackables(tm)
            self.trainer.train(tm, specs, cache_theta=kwargs.get('cache_theta', True))
            self.experiment.save_experiment(save_phi=True)

    def _topics_str(self, topics, domain, background):
        """Name `topics` as 'domain' or 'background' when it equals that list, else list them."""
        if topics == domain:
            return 'domain'
        if topics == background:
            return 'background'
        return '[{}]'.format(', '.join(topics))

    def _model(self):
        """Build the model and the train specs for the current parameter vector.

        :return: tuple of (topic model, train specs)
        """
        # Modalities whose weight is falsy (0 or None) are left out.
        modality_weights = {
            DEFAULT_CLASS_NAME: self._val('default_class_weight'),
            IDEOLOGY_CLASS_NAME: self._val('ideology_class_weight'),
        }
        tm = self.trainer.model_factory.construct_model(
            self._cur_label,
            self._val('nb_topics'),
            self._val('collection_passes'),
            self._val('document_passes'),
            self._val('background_topics_pct'),
            {k: v for k, v in modality_weights.items() if v},
            self.scores,
            self._reg_specs.types,
            reg_settings=self.current_reg_specs)  # a dictionary mapping reg_types to reg_specs
        tr_specs = self.trainer.model_factory.create_train_specs(self._val('collection_passes'))
        return tm, tr_specs

    def _val(self, parameter_name):
        """Value of `parameter_name` in the current parameter vector ('_' is read as '-')."""
        return self.extract(self.parameter_vector, parameter_name.replace('_', '-'))

    def extract(self, parameters_vector, parameter_name):
        """Return the value of a parameter from a grid-search parameter vector.

        :param parameters_vector: one combination produced by the parameter grid
        :param str parameter_name: a training parameter name, or '<reg_type>.<param>'
        :raises ValueError: if the name contains more than one '.'
        """
        r = parameter_name.split('.')
        if len(r) == 1:
            return self._training_parameters.extract(parameters_vector, parameter_name)
        if len(r) == 2:
            return self._reg_specs.extract(parameters_vector, r[0], r[1])
        raise ValueError("Either input a training parameter such as 'collection_passes', 'nb_topics', 'ideology_class_weight' or a regularizer's parameter in format such as 'sparse-phi.tau', 'label-regularization-phi-dom-def.tau'")

    def __format_reg(self, reg_specs, reg_type):
        """Return (reg_type, name) when the specs carry a 'name', else just reg_type.

        Note that the 'name' entry is popped from (removed from) `reg_specs[reg_type]`.
        """
        if 'name' in reg_specs[reg_type]:
            return reg_type, reg_specs[reg_type].pop('name')
        return reg_type
|
206
|
|
|
|
|
207
|
|
|
|
|
208
|
|
|
############## PARAMETERS MIXTURE ############## |
|
209
|
|
|
|
|
210
|
|
|
|
|
211
|
|
|
@attr.s
class RegularizationSpecifications(object):
    """Ordered mapping of regularizer type to the ParametersMixture of its settings.

    Built from an iterable of ``(reg_type, [(param_name, value), ...])`` pairs.
    The derived attributes flatten the per-regularizer mixtures in insertion
    order, with names qualified as ``'<reg_type>.<param_name>'``.
    """
    reg_specs = attr.ib(init=True, converter=lambda x: OrderedDict(
        [(reg_type, ParametersMixture([(param_name, value) for param_name, value in settings]))
         for reg_type, settings in x]), repr=True, cmp=True)
    parameter_spans = attr.ib(init=False, default=attr.Factory(
        lambda self: [span for mixture in self.reg_specs.values() for span in mixture.parameter_spans],
        takes_self=True))
    parameter_names = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name)
                      for reg_type, mixture in self.reg_specs.items() for param_name in mixture.parameter_names],
        takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name)
                      for reg_type, mixture in self.reg_specs.items() for param_name in mixture.steady],
        takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(
        lambda self: ['{}.{}'.format(reg_type, param_name)
                      for reg_type, mixture in self.reg_specs.items() for param_name in mixture.explorable],
        takes_self=True))

    # The initial value 1 keeps an empty specification constructible (the empty product is 1).
    nb_combinations = attr.ib(init=False, default=attr.Factory(
        lambda self: reduce(lambda i, j: i * j, [v.nb_combinations for v in self.reg_specs.values()], 1),
        takes_self=True))

    def __getitem__(self, item):
        return self.reg_specs[item]

    def __iter__(self):
        """Iterate over (reg_type, parameters mixture) pairs in insertion order."""
        return iter(self.reg_specs.items())

    def extract(self, parameter_vector, reg_name, reg_param):
        """Return the value of a regularizer parameter from a full parameter vector.

        The regularization values are taken to be the last
        ``len(self.parameter_spans)`` entries of `parameter_vector`.

        :param parameter_vector: sequence of values, regularization values last
        :param str reg_name: regularizer type
        :param str reg_param: parameter name within that regularizer
        :raises KeyError: if `reg_name` is not a known regularizer type
        """
        if reg_name not in self.reg_specs:
            raise KeyError(reg_name)
        nb_values = len(self.parameter_spans)
        # Guard the zero case: seq[-0:] would return the whole sequence, not an empty tail.
        reg_values = list(parameter_vector[-nb_values:]) if nb_values else []
        # Skip the values that belong to the regularizers defined before the requested one.
        offset = 0
        for reg_type, mixture in self.reg_specs.items():
            if reg_type == reg_name:
                break
            offset += mixture.length
        return self.reg_specs[reg_name].extract(reg_values[offset:], reg_param)

    @property
    def types(self):
        """List of the regularizer types, in insertion order."""
        return list(self.reg_specs.keys())
|
249
|
|
|
|
|
250
|
|
|
def _conv(value): |
|
251
|
|
|
if type(value) != list: |
|
252
|
|
|
return [value] |
|
253
|
|
|
return value |
|
254
|
|
|
|
|
255
|
|
|
|
|
256
|
|
|
def _build(tuple_list):
    """Build an OrderedDict of parameter name -> list of candidate values.

    :param tuple_list: iterable of (name, value) items, or a dict. A value that
        is not a list is wrapped into a one-element list.
    :return: OrderedDict keeping the input order
    :raises ValueError: if a name appears more than once
    """
    # A dict is read through its items; any other iterable is materialized once so
    # that generators can be both checked for duplicates and consumed.
    pairs = list(tuple_list.items()) if isinstance(tuple_list, dict) else list(tuple_list)
    if len(pairs) != len(set(x[0] for x in pairs)):
        raise ValueError("Input tuples should behave like a dict (unique elements as each 1st element)")
    return OrderedDict([(x[0], _conv(x[1])) for x in pairs])
|
260
|
|
|
|
|
261
|
|
|
@attr.s
class ParametersMixture(object):
    """An ordered mapping with parameter names as keys and, as values, lists of candidate values.

    A parameter given as a single object is stored as a one-element list
    (a "steady" parameter). A parameter given as a list models a span/grid
    of values for grid-search (an "explorable" parameter when it has more than one value).
    """
    _data_hash = attr.ib(init=True, converter=_build)

    length = attr.ib(init=False, default=attr.Factory(lambda self: len(self._data_hash), takes_self=True))
    parameter_names = attr.ib(init=False, default=attr.Factory(lambda self: list(self._data_hash.keys()), takes_self=True))
    # The initial value 1 keeps an empty mixture constructible (the empty product is 1).
    nb_combinations = attr.ib(init=False, default=attr.Factory(
        lambda self: reduce(lambda i, j: i * j, [len(v) for v in self._data_hash.values()], 1), takes_self=True))
    steady = attr.ib(init=False, default=attr.Factory(
        lambda self: [name for name, assumable_values in self._data_hash.items() if len(assumable_values) == 1],
        takes_self=True))
    explorable = attr.ib(init=False, default=attr.Factory(
        lambda self: [name for name, assumable_values in self._data_hash.items() if len(assumable_values) > 1],
        takes_self=True))
    parameter_spans = attr.ib(init=False, default=attr.Factory(lambda self: list(self._data_hash.values()), takes_self=True))

    def __contains__(self, item):
        return item in self._data_hash

    def __len__(self):
        return len(self._data_hash)

    def __getitem__(self, item):
        return self._data_hash[item]

    def __iter__(self):
        """Iterate over (parameter name, candidate values) pairs in insertion order."""
        return iter(self._data_hash.items())

    def extract(self, parameter_vector, parameter):
        """Return the entry of `parameter_vector` at the position of `parameter`.

        :raises ValueError: if `parameter` is not one of the parameter names
        """
        return parameter_vector[self.parameter_names.index(parameter)]

    @classmethod
    def from_regularization_settings(cls, reg_settings):
        """Alternate constructor from an iterable of (name, value) pairs."""
        return cls([(k, v) for k, v in reg_settings])
|
291
|
|
|
|
|
292
|
|
|
############## LABELING ############## |
|
293
|
|
|
|
|
294
|
|
|
def _check_extractor(self, attr_attribute_obj, input_value): |
|
295
|
|
|
if not hasattr(input_value, '__call__'): |
|
296
|
|
|
raise ValueError("A callable is required") |
|
297
|
|
|
if input_value.__code__.co_argcount != 2: |
|
298
|
|
|
raise ValueError("Callable should accept exactly 2 arguments. First is a parameter vector (list) and second a parameter name") |
|
299
|
|
|
|
|
300
|
|
|
|
|
301
|
|
|
@attr.s
class LabelingDefinition(object):
    """Build a model label from selected parameter values of a parameter vector.

    Calling the instance joins, with '_', the optional prefix and the formatted
    value of every parameter listed in `parameters`; each value is obtained with
    ``extractor(parameter_vector, parameter_name)``.
    """
    parameters = attr.ib(init=True, converter=list, repr=True, cmp=True)
    extractor = attr.ib(init=True, validator=_check_extractor)
    _prefix = attr.ib(init=True, default='')  #, converter=str)

    def __call__(self, parameter_vector):
        """Return the label for `parameter_vector`; empty components are skipped."""
        parts = [self._prefix]
        parts.extend(self._conv(self.extractor(parameter_vector, param_name)) for param_name in self.parameters)
        return '_'.join(part for part in parts if part)

    def _conv(self, v):
        """Format a parameter value for use in a label.

        Numbers >= 1e4 are written with 2 significant digits, other integral
        numbers without a fractional part; anything else is ``str(v)``.
        """
        try:
            number = float(v)
            if number >= 1e4:
                return "{:.2}".format(number)
            if int(number) == number:
                return str(int(number))
        except (TypeError, ValueError, OverflowError):
            # Not number-like (TypeError/ValueError) or not finite (nan/-inf): use the plain string.
            pass
        return str(v)

    @classmethod
    def from_training_parameters(cls, training_parameters, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False):
        """Create a labeling definition that reads values through `training_parameters`.

        :param training_parameters: object exposing `steady`, `explorable`,
            `parameter_names` and `extract(vector, name)`
        :param str prefix: optional first component of every label
        :param list labeling_params: if given (non-empty), exactly the parameters to label with
        :param append_static: a list of names to use, or a truthy flag to use all steady parameters
        :param append_explorable: a list of names to add, or a truthy flag to add all explorable parameters
        :param bool preserve_order: keep only names found in `parameter_names`, in that order
        :raises ValueError: if `labeling_params` is given but is not a list
        """
        if labeling_params:
            if not isinstance(labeling_params, list):
                raise ValueError("If given the labeling_params argument should be a list")
        else:
            # Always work on a fresh list so the caller's arguments are never mutated below.
            if isinstance(append_static, list):
                labeling_params = list(append_static)
            elif append_static:
                labeling_params = list(training_parameters.steady)
            else:
                labeling_params = []
            if isinstance(append_explorable, list):
                labeling_params.extend(append_explorable)
            elif append_explorable:
                labeling_params.extend(training_parameters.explorable)
        if preserve_order:
            labeling_params = [x for x in training_parameters.parameter_names if x in labeling_params]
        return cls(labeling_params, lambda vector, param: training_parameters.extract(vector, param), prefix)

    @classmethod
    def from_tuner(cls, tuner, prefix='', labeling_params=None, append_static=False, append_explorable=True, preserve_order=False, parameter_set='training|regularization'):
        """Create a labeling definition that reads values through ``tuner.extract``.

        The parameters to label with are chosen by :meth:`select`.
        """
        selected = cls.select(tuner, labeling_params=labeling_params, append_static=append_static,
                              append_explorable=append_explorable, preserve_order=preserve_order,
                              parameter_set=parameter_set)
        return cls(selected, lambda vector, param: tuner.extract(vector, param), prefix)

    @classmethod
    def select(cls, tuner, **kwargs):
        """Choose the parameter names to build labels from.

        :param tuner: object indexable by parameter-set name and exposing `parameter_names`
        :param kwargs: 'labeling_params', 'append_static', 'append_explorable',
            'preserve_order' and 'parameter_set' ('|'-separated set names,
            default 'training|regularization')
        :return: list of parameter names
        :raises ValueError: if 'labeling_params' is given but is not a list
        """
        labeling_params = []
        requested = kwargs.get('labeling_params', None)
        if requested:
            if not isinstance(requested, list):
                raise ValueError("If given, the labeling_params keyword-argument should be a list")
            labeling_params = requested
        else:
            set_names = kwargs.get('parameter_set', 'training|regularization').split('|')
            append_static = kwargs.get('append_static', False)
            if isinstance(append_static, list):
                labeling_params.extend(append_static)
            elif append_static:
                labeling_params.extend([x for el in set_names for x in tuner[el].steady])
            append_explorable = kwargs.get('append_explorable', False)
            if isinstance(append_explorable, list):
                labeling_params.extend(append_explorable)
            elif append_explorable:
                labeling_params.extend([x for el in set_names for x in tuner[el].explorable])
        if kwargs.get('preserve_order', False):
            # labeling_params = [x for x in [y for el in kwargs.get('parameter_set', 'training|regularization').split('|') for y in tuner[el].parameter_names] if x in labeling_params]
            labeling_params = [x for x in tuner.parameter_names if x in labeling_params]
        return labeling_params
|
367
|
|
|
|