1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# coding=utf-8 |
3
|
|
|
from __future__ import division, print_function, unicode_literals |
4
|
|
|
|
5
|
|
|
import os |
6
|
|
|
from collections import OrderedDict, defaultdict |
7
|
|
|
from copy import copy, deepcopy |
8
|
|
|
|
9
|
|
|
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize, |
10
|
|
|
load_config_file, undogmatize) |
11
|
|
|
from sacred.config.config_summary import ConfigSummary |
12
|
|
|
from sacred.host_info import get_host_info |
13
|
|
|
from sacred.randomness import create_rnd, get_seed |
14
|
|
|
from sacred.run import Run |
15
|
|
|
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger, |
16
|
|
|
get_by_dotted_path, is_prefix, rel_path, |
17
|
|
|
iterate_flattened, set_by_dotted_path, |
18
|
|
|
recursive_update, iter_prefixes, join_paths) |
19
|
|
|
|
20
|
|
|
|
21
|
|
|
class Scaffold(object):
    """Working state for one ingredient (or the experiment) while a run is set up.

    A scaffold keeps the config scopes, named configs, config hooks,
    captured functions and commands it was constructed with, together with
    the values derived from them during run creation: ``config``,
    ``fallback``, ``presets``, ``fixture``, ``seed``, ``rnd`` and
    ``config_mods``.

    ``subrunners`` maps dotted paths to other ``Scaffold`` instances; the
    methods below iterate it with ``.items()``.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        """Store the given collaborators and reset all derived state.

        ``generate_seed`` controls whether ``set_up_seed`` writes the chosen
        seed back into ``self.config``.
        """
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        self.subrunners = subrunners
        self.path = path
        self.generate_seed = generate_seed
        self.config_hooks = config_hooks
        self.config_updates = {}
        self.named_configs_to_use = []
        self.config = {}
        self.fallback = None
        self.presets = {}
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None
        self._captured_functions = captured_functions
        self.commands = commands
        self.config_mods = None
        self.summaries = []
        # Every argument name of every captured function, joined with that
        # function's prefix; used by _warn_about_suspicious_changes to decide
        # whether a newly added config entry is used anywhere.
        self.captured_args = {join_paths(cf.prefix, n)
                              for cf in self._captured_functions
                              for n in cf.signature.arguments}
        self.captured_args.add('__doc__')  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Pick the seed and random state for this scaffold and its subrunners.

        Does nothing if ``self.seed`` is already set. Otherwise the seed is
        read from ``self.config['seed']`` and, if that is missing or None,
        obtained from ``get_seed(rnd)``. ``self.rnd`` is then created from
        the seed.

        ``self.config_mods`` must already be set (see
        ``get_config_modifications``), since its ``added`` set is read here.
        """
        if self.seed is not None:
            return

        self.seed = self.config.get('seed')
        if self.seed is None:
            self.seed = get_seed(rnd)

        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # A 'seed' entry that was recorded as added is re-labelled as
        # modified instead.
        if 'seed' in self.config and 'seed' in self.config_mods.added:
            self.config_mods.modified.add('seed')
            self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        for subrunner_path, subrunner in reversed(list(
                self.subrunners.items())):
            if is_prefix(self.path, subrunner_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Build ``self.fallback`` from the logger and the subrunner configs.

        The fallback dict holds ``self.logger`` under ``'_log'`` and each
        subrunner's ``config`` under its path. When ``self.path`` is
        non-empty and ``is_prefix(self.path, sr_path)`` holds, the leading
        ``self.path`` part and surrounding dots are stripped from that path.
        """
        fallback = {'_log': self.logger}
        for sr_path, subrunner in self.subrunners.items():
            if self.path and is_prefix(self.path, sr_path):
                path = sr_path[len(self.path):].strip('.')
                set_by_dotted_path(fallback, path, subrunner.config)
            else:
                set_by_dotted_path(fallback, sr_path, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(fallback)
        # NOTE(review): the effect of revelation() is defined on the
        # dogmatized object outside this file; its return value is ignored.
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate one named config and return the undogmatized result.

        If ``config_name`` is an existing filesystem path it is loaded with
        ``load_config_file`` and wrapped in a ``ConfigDict``; otherwise it is
        looked up in ``self.named_configs``. The config is called with the
        recursive config updates as ``fixed``, ``self.presets`` as ``preset``
        and ``self.fallback`` as ``fallback``.
        """
        if os.path.exists(config_name):
            nc = ConfigDict(load_config_file(config_name))
        else:
            nc = self.named_configs[config_name]

        cfg = nc(fixed=self.get_config_updates_recursive(),
                 preset=self.presets,
                 fallback=self.fallback)

        return undogmatize(cfg)

    def set_up_config(self):
        """Evaluate the config scopes and refresh ``self.config_mods``.

        Stores the two values returned by ``chain_evaluate_config_scopes`` as
        ``self.config`` and ``self.summaries``. The previous ``self.config``
        is passed in as ``preset``.
        """
        self.config, self.summaries = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Run all config hooks and return the combined config updates.

        Each hook is called with a deep copy of ``config``, the command name
        and the logger; truthy return values are merged into a fresh dict via
        ``recursive_update``. ``config_updates`` is merged in last.
        """
        final_cfg_updates = {}
        for ch in self.config_hooks:
            cfg_upup = ch(deepcopy(config), command_name, logger)
            if cfg_upup:
                recursive_update(final_cfg_updates, cfg_upup)
        recursive_update(final_cfg_updates, config_updates)
        return final_cfg_updates

    def get_config_modifications(self):
        """Rebuild ``self.config_mods`` from the updates and the summaries.

        Starts from a ``ConfigSummary`` whose ``added`` set contains every
        flattened key of ``self.config_updates`` and then folds in each entry
        of ``self.summaries`` with ``update_from``.
        """
        self.config_mods = ConfigSummary(
            added={key
                   for key, value in iterate_flattened(self.config_updates)})
        for cfg_summary in self.summaries:
            self.config_mods.update_from(cfg_summary)

    def get_config_updates_recursive(self):
        """Return this scaffold's config updates including its subrunners'.

        The result is a shallow copy of ``self.config_updates``. For every
        subrunner whose path passes ``is_prefix(self.path, sr_path)`` and
        whose own recursive updates are non-empty, those updates are stored
        under the key ``rel_path(self.path, sr_path)``.
        """
        config_updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            if not is_prefix(self.path, sr_path):
                continue
            update = subrunner.get_config_updates_recursive()
            if update:
                config_updates[rel_path(self.path, sr_path)] = update
        return config_updates

    def get_fixture(self):
        """Return the fixture, building and caching it on first use.

        The fixture is a shallow copy of ``self.config`` into which a shallow
        copy of every (transitive) subrunner's config is set under the
        subrunner's path; that path is made relative to ``self.path`` when
        ``is_prefix(self.path, sub_path)`` holds.
        """
        if self.fixture is not None:
            return self.fixture

        def get_fixture_recursive(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                subrunner.get_fixture()
                get_fixture_recursive(subrunner)
                sub_fix = copy(subrunner.config)
                sub_path = sr_path
                if is_prefix(self.path, sub_path):
                    sub_path = sr_path[len(self.path):].strip('.')
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, sub_path, sub_fix)

        # self.fixture must be assigned before the recursion, because the
        # nested helper writes into it.
        self.fixture = copy(self.config)
        get_fixture_recursive(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Attach logger, config, random state and run to captured functions.

        Unless ``run.force`` is truthy, suspicious config changes are
        reported afterwards, which may raise ``KeyError`` (see
        ``_warn_about_suspicious_changes``).
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if 'seed' in self.config:
            self.seed = self.config['seed']
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            # Each captured function gets its own random state, seeded from
            # this scaffold's random state.
            seed = get_seed(self.rnd)
            cfunc.rnd = create_rnd(seed)
            cfunc.run = run

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Log warnings for notable config modifications.

        Raises:
            KeyError: if an added config entry has no dotted prefix (as
                produced by ``iter_prefixes``) in ``self.captured_args``.
        """
        for add in sorted(self.config_mods.added):
            if not set(iter_prefixes(add)).intersection(self.captured_args):
                raise KeyError('Added a new config entry "{}" that is not used'
                               ' anywhere'.format(add))
            else:
                self.logger.warning('Added new config entry: "%s"' % add)

        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            # Changes between int and float are not reported.
            if type_old in (int, float) and type_new in (int, float):
                continue
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s' %
                (key, type_old.__name__, type_new.__name__))

        for cfg_summary in self.summaries:
            for key in cfg_summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    'ingredient.' % key
                )

    def __repr__(self):
        """Return a short representation containing the scaffold's path."""
        return "<Scaffold: '{}'>".format(self.path)
194
|
|
|
|
195
|
|
|
|
196
|
|
|
def get_configuration(scaffolding):
    """Combine the configs of all scaffolds into a single dictionary.

    ``scaffolding`` maps dotted paths to objects with a ``config``
    attribute. Entries are processed in reverse iteration order. Falsy
    configs are skipped, the config stored under the empty path is merged
    into the top level, and every other config is placed under its dotted
    path.
    """
    merged = {}
    entries = list(scaffolding.items())
    entries.reverse()
    for dotted_path, scaff in entries:
        if scaff.config:
            if not dotted_path:
                merged.update(scaff.config)
            else:
                set_by_dotted_path(merged, dotted_path, scaff.config)
    return merged
206
|
|
|
|
207
|
|
|
|
208
|
|
|
def distribute_named_configs(scaffolding, named_configs):
    """Hand every named config to the scaffold responsible for it.

    A name that is an existing filesystem path goes to the scaffold stored
    under the empty path. Any other name is split at its last dot into an
    ingredient path and a config name, and the config name goes to the
    scaffold stored under that ingredient path.

    Raises:
        KeyError: if the ingredient path is not a key of ``scaffolding``.
    """
    # NOTE(review): the Scaffold class in this file defines no
    # ``use_named_config`` method -- confirm whether this helper is still
    # used with Scaffold instances.
    for name in named_configs:
        if os.path.exists(name):
            scaffolding[''].use_named_config(name)
            continue

        ingredient_path, _, short_name = name.rpartition('.')
        if ingredient_path not in scaffolding:
            raise KeyError('Ingredient for named config "{}" not found'
                           .format(name))
        scaffolding[ingredient_path].use_named_config(short_name)
218
|
|
|
|
219
|
|
|
|
220
|
|
|
def initialize_logging(experiment, scaffolding, log_level=None):
    """Assign a logger to every scaffold and return the root and run loggers.

    The root logger is ``experiment.logger`` or, if that is None, a new one
    from ``create_basic_stream_logger()``. The scaffold stored under the
    empty path gets the root logger itself; every other scaffold gets a
    child logger named after its path.

    If ``log_level`` is not None it is applied to the root logger, converted
    to ``int`` when possible and passed through unchanged otherwise.

    Returns:
        A ``(root_logger, run_logger)`` tuple, where the run logger is the
        child of the root logger named ``experiment.path``.
    """
    root_logger = experiment.logger
    if root_logger is None:
        root_logger = create_basic_stream_logger()

    for dotted_path, scaff in scaffolding.items():
        scaff.logger = (root_logger.getChild(dotted_path) if dotted_path
                        else root_logger)

    # set log level
    if log_level is not None:
        try:
            level = int(log_level)
        except ValueError:
            # not numeric: hand the value to setLevel unchanged
            level = log_level
        root_logger.setLevel(level)

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
241
|
|
|
|
242
|
|
|
|
243
|
|
|
def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient plus one for the experiment.

    All entries of ``sorted_ingredients`` except the last get a scaffold
    with ``generate_seed=False``. The experiment gets a scaffold with the
    empty path and ``generate_seed=True``. A scaffold's subrunners are the
    scaffolds already created for its own ``ingredients``, so those must
    come earlier in ``sorted_ingredients``.

    Returns:
        An ``OrderedDict`` mapping each scaffold's path to the scaffold.

    Raises:
        ValueError: if two scaffolds share the same path.
    """
    by_ingredient = OrderedDict()

    def subrunners_of(owner):
        # Scaffolds of the owner's direct ingredients, keyed by their path.
        pairs = []
        for member in owner.ingredients:
            member_scaffold = by_ingredient[member]
            pairs.append((member_scaffold.path, member_scaffold))
        return OrderedDict(pairs)

    for ingredient in sorted_ingredients[:-1]:
        by_ingredient[ingredient] = Scaffold(
            config_scopes=ingredient.configurations,
            subrunners=subrunners_of(ingredient),
            path=ingredient.path,
            captured_functions=ingredient.captured_functions,
            commands=ingredient.commands,
            named_configs=ingredient.named_configs,
            config_hooks=ingredient.config_hooks,
            generate_seed=False)

    by_ingredient[experiment] = Scaffold(
        config_scopes=experiment.configurations,
        subrunners=subrunners_of(experiment),
        path='',
        captured_functions=experiment.captured_functions,
        commands=experiment.commands,
        named_configs=experiment.named_configs,
        config_hooks=experiment.config_hooks,
        generate_seed=True)

    by_path = OrderedDict()
    for scaff in by_ingredient.values():
        by_path[scaff.path] = scaff

    # Fewer entries after re-keying means at least two paths collided.
    if len(by_path) != len(by_ingredient):
        all_paths = [ing.path for ing in by_ingredient]
        raise ValueError(
            'The pathes of the ingredients are not unique. {}'
            .format(all_paths))

    return by_path
279
|
|
|
|
280
|
|
|
|
281
|
|
|
def gather_ingredients_topological(ingredient):
    """Return all traversed ingredients ordered from deepest to shallowest.

    ``ingredient.traverse_ingredients()`` must yield ``(ingredient, depth)``
    pairs. For an ingredient yielded more than once, the largest depth
    counts (never less than 0). Ingredients of equal depth keep the order in
    which they were first yielded.
    """
    max_depth = {}
    for sub_ing, depth in ingredient.traverse_ingredients():
        max_depth[sub_ing] = max(max_depth.get(sub_ing, 0), depth)
    # reverse=True keeps the sort stable, so ties stay in first-seen order
    return sorted(max_depth, key=max_depth.get, reverse=True)
286
|
|
|
|
287
|
|
|
|
288
|
|
|
def get_config_modifications(scaffolding):
    """Collect the config modifications of all scaffolds into one summary.

    Creates an empty ``ConfigSummary`` and, for each ``(path, scaffold)``
    entry of ``scaffolding``, calls ``update_add`` on it with the scaffold's
    ``config_mods`` and ``path=path``.
    """
    combined = ConfigSummary()
    for dotted_path in scaffolding:
        mods = scaffolding[dotted_path].config_mods
        combined.update_add(mods, path=dotted_path)
    return combined
293
|
|
|
|
294
|
|
|
|
295
|
|
|
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    ``command_path`` is split at its last dot into an ingredient path and a
    command name; without a dot the ingredient path is the empty string.

    Returns:
        The entry for the command name in the ``commands`` of the scaffold
        stored under the ingredient path.

    Raises:
        KeyError: if the ingredient path is not in ``scaffolding`` or the
            scaffold has no command of that name.
    """
    ingredient_path, _, name = command_path.rpartition('.')
    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    available = scaffolding[ingredient_path].commands
    if name in available:
        return available[name]

    if not ingredient_path:
        raise KeyError('Command "%s" not found' % name)
    raise KeyError('Command "%s" not found in ingredient "%s"' %
                   (name, ingredient_path))
308
|
|
|
|
309
|
|
|
|
310
|
|
|
def find_best_match(path, prefixes):
    """Find the ingredient prefix that shares the longest prefix with path.

    Args:
        path: dotted path such as ``'ing.sub.key'``.
        prefixes: iterable of prefixes, each a list of path parts such as
            ``['ing', 'sub']``. They may be given in any order.

    Returns:
        A ``(prefix, rest)`` tuple of dotted strings: the longest prefix
        matching the leading parts of ``path`` and the remainder of ``path``
        after it. If no prefix matches, ``('', path)`` is returned.
    """
    path_parts = path.split('.')
    best = None
    for p in prefixes:
        if not (len(p) <= len(path_parts) and p == path_parts[:len(p)]):
            continue
        # Keep the longest match rather than the first one, so the result
        # does not depend on the caller pre-sorting the prefixes. On ties
        # the earliest prefix wins.
        if best is None or len(p) > len(best):
            best = p
    if best is None:
        return '', path
    return '.'.join(best), '.'.join(path_parts[len(best):])
317
|
|
|
|
318
|
|
|
|
319
|
|
|
def distribute_presets(prefixes, scaffolding, config_updates):
    """Store every flattened entry of ``config_updates`` as a preset.

    For each ``(path, value)`` pair from ``iterate_flattened``, the owning
    scaffold and the remaining path are found with ``find_best_match``, and
    the value is set at that remaining path in the scaffold's ``presets``.
    """
    for dotted_key, preset_value in iterate_flattened(config_updates):
        owner_name, local_key = find_best_match(dotted_key, prefixes)
        owner = scaffolding[owner_name]
        set_by_dotted_path(owner.presets, local_key, preset_value)
324
|
|
|
|
325
|
|
|
|
326
|
|
|
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Store every flattened entry of ``config_updates`` on its scaffold.

    For each ``(path, value)`` pair from ``iterate_flattened``, the owning
    scaffold and the remaining path are found with ``find_best_match``, and
    the value is set at that remaining path in the scaffold's
    ``config_updates``.
    """
    for dotted_key, new_value in iterate_flattened(config_updates):
        owner_name, local_key = find_best_match(dotted_key, prefixes)
        owner = scaffolding[owner_name]
        set_by_dotted_path(owner.config_updates, local_key, new_value)
331
|
|
|
|
332
|
|
|
|
333
|
|
|
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Resolve a named config to its scaffold and its local name.

    If ``named_config`` is an existing filesystem path, it belongs to the
    scaffold under the empty path and is returned unchanged. Otherwise it is
    split at its last dot into an ingredient path and a config name.

    Returns:
        A ``(scaffold, config_name)`` tuple.

    Raises:
        KeyError: if the resolved ingredient path is not in ``scaffolding``.
    """
    is_file = os.path.exists(named_config)
    if not is_file:
        ingredient_path, _, local_name = named_config.rpartition('.')
    else:
        ingredient_path = ''
        local_name = named_config

    if ingredient_path in scaffolding:
        return scaffolding[ingredient_path], local_name

    raise KeyError('Ingredient for named config "{}" not found'
                   .format(named_config))
344
|
|
|
|
345
|
|
|
|
346
|
|
|
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    """Build the ``Run`` object for one command of an experiment.

    The configuration is assembled in four phases: distribute the given
    config updates, evaluate the named configs, evaluate the normal config
    scopes (running config hooks after each scaffold), and set up the seeds.
    Afterwards the ``Run`` is created and every scaffold is finalized with
    it.

    Args:
        experiment: the experiment to run; its ingredients, configurations,
            commands, hooks, observers and logger are read here.
        command_name: dotted path of the command, resolved by
            ``get_command``.
        config_updates: optional config updates; None is treated as an
            empty dict, and the value is passed through
            ``convert_to_nested_dict``.
        named_configs: iterable of named config names or config file paths.
        force: stored as ``run.force``; when truthy, scaffolds skip their
            check for suspicious config changes.
        log_level: optional level passed on to ``initialize_logging``.

    Returns:
        The created ``Run``.

    Raises:
        KeyError: if a named config or the command refers to an unknown
            ingredient, if the command does not exist, or (unless ``force``)
            if a scaffold finds an added config entry that is not used.
    """

    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # Record the named config's values in config_updates as well, under
        # the path of the scaffold that owns the named config.
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    # Distribute again so the scaffolds also see the entries added in
    # phase 2.
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_updates = scaffold.run_config_hooks(config, config_updates,
                                                   command_name, run_logger)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    # Collect the config again, since set_up_seed may have written a seed
    # into a scaffold's config.
    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    # The run gets a shallow copy of the experiment's observers.
    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
417
|
|
|
|