1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# coding=utf-8 |
3
|
|
|
from __future__ import division, print_function, unicode_literals |
4
|
|
|
|
5
|
|
|
import os |
6
|
|
|
from collections import OrderedDict, defaultdict |
7
|
|
|
from copy import copy, deepcopy |
8
|
|
|
|
9
|
|
|
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize, |
10
|
|
|
load_config_file, undogmatize) |
11
|
|
|
from sacred.config.config_summary import ConfigSummary |
12
|
|
|
from sacred.host_info import get_host_info |
13
|
|
|
from sacred.randomness import create_rnd, get_seed |
14
|
|
|
from sacred.run import Run |
15
|
|
|
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger, |
16
|
|
|
get_by_dotted_path, is_prefix, |
17
|
|
|
iterate_flattened, set_by_dotted_path, |
18
|
|
|
recursive_update, iter_prefixes, join_paths) |
19
|
|
|
|
20
|
|
|
# Module marker: files that define ``__sacred__`` should be filtered from
# stack traces.
__sacred__ = True
21
|
|
|
|
22
|
|
|
|
23
|
|
|
class Scaffold(object):
    """Drive the configuration of a single ingredient while creating a run.

    ``create_scaffolding`` builds one Scaffold per ingredient; the
    experiment's own Scaffold has the empty path ``''``. ``create_run`` then
    uses the scaffolds to collect config updates and presets, evaluate the
    config scopes, set up seeding, and finally inject config, logger, random
    state and the run object into the captured functions.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        # path -> Scaffold of this ingredient's sub-ingredients
        # (as built by create_scaffolding)
        self.subrunners = subrunners
        # dotted path of this ingredient ('' for the experiment)
        self.path = path
        # if True, set_up_seed writes the seed into self.config
        self.generate_seed = generate_seed
        self.config_hooks = config_hooks
        self.config_updates = {}
        self.named_configs_to_use = []
        self.config = {}
        self.fallback = None
        self.presets = {}
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None
        self._captured_functions = captured_functions
        self.commands = commands
        self.config_mods = None
        self.summaries = []
        # every (dotted) argument name some captured function can consume
        self.captured_args = {join_paths(cf.prefix, n)
                              for cf in self._captured_functions
                              for n in cf.signature.arguments}
        self.captured_args.add('__doc__')  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Determine this scaffold's seed and seed its proper subrunners.

        Does nothing if a seed has already been set. Otherwise the seed is
        read from ``config['seed']`` or, if absent, drawn via
        ``get_seed(rnd)``.

        :param rnd: random state of the parent scaffold, or None.
        """
        if self.seed is not None:
            return

        self.seed = self.config.get('seed')
        if self.seed is None:
            self.seed = get_seed(rnd)

        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # Report a user-supplied seed as a modification rather than an
        # addition, so it is not subject to the "unused added entry" check in
        # _warn_about_suspicious_changes. config_mods only exists once
        # set_up_config has run, hence the None guard.
        if (self.config_mods is not None and 'seed' in self.config and
                'seed' in self.config_mods.added):
            self.config_mods.modified.add('seed')
            self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        for subrunner_path, subrunner in reversed(list(
                self.subrunners.items())):
            if is_prefix(self.path, subrunner_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Collect the configs of all subrunners into ``self.fallback``.

        Subrunners whose path lies below this scaffold's (non-empty) path are
        stored relative to it; all others are stored under their full path.
        """
        fallback = {}
        for sr_path, subrunner in self.subrunners.items():
            if self.path and is_prefix(self.path, sr_path):
                path = sr_path[len(self.path):].strip('.')
                set_by_dotted_path(fallback, path, subrunner.config)
            else:
                set_by_dotted_path(fallback, sr_path, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(fallback)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate a named config and return its values as plain dicts.

        If ``config_name`` is an existing file path, the config is loaded from
        that file; otherwise it is looked up in ``self.named_configs``. The
        recursive config updates are passed as fixed values, together with
        the current presets and fallback.
        """
        if os.path.exists(config_name):
            nc = ConfigDict(load_config_file(config_name))
        else:
            nc = self.named_configs[config_name]

        cfg = nc(fixed=self.get_config_updates_recursive(),
                 preset=self.presets,
                 fallback=self.fallback)

        return undogmatize(cfg)

    def set_up_config(self):
        """Evaluate the config scopes and record the resulting modifications."""
        self.config, self.summaries = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Run all config hooks and merge their results with config_updates.

        Each hook is called with a deep copy of ``config``, the command name
        and the logger. The dicts the hooks return are merged in order, and
        ``config_updates`` is merged last.

        :return: the merged dictionary of updates.
        """
        final_cfg_updates = {}
        for ch in self.config_hooks:
            cfg_upup = ch(deepcopy(config), command_name, logger)
            if cfg_upup:
                recursive_update(final_cfg_updates, cfg_upup)
        recursive_update(final_cfg_updates, config_updates)
        return final_cfg_updates

    def get_config_modifications(self):
        """Compute ``self.config_mods`` for this scaffold.

        Every flattened key of ``self.config_updates`` starts out in
        ``added``; the summaries of the evaluated config scopes are then
        folded in.
        """
        self.config_mods = ConfigSummary(
            added={key for key, _ in iterate_flattened(self.config_updates)})
        for cfg_summary in self.summaries:
            self.config_mods.update_from(cfg_summary)

    def get_config_updates_recursive(self):
        """Return the config updates of this scaffold and its subrunners.

        The result is a shallow copy of ``self.config_updates``, with the
        recursive updates of each subrunner stored under its path as key.
        """
        config_updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            config_updates[sr_path] = subrunner.get_config_updates_recursive()
        return config_updates

    def get_fixture(self):
        """Return (and cache) the config including all subrunner configs.

        The result is a shallow copy of ``self.config``, with each
        subrunner's config placed below its path. If that path lies under
        this scaffold's path, the relative part is used.
        """
        if self.fixture is not None:
            return self.fixture

        def get_fixture_recursive(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                subrunner.get_fixture()
                get_fixture_recursive(subrunner)
                sub_fix = copy(subrunner.config)
                sub_path = sr_path
                if is_prefix(self.path, sub_path):
                    sub_path = sr_path[len(self.path):].strip('.')
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, sub_path, sub_fix)

        self.fixture = copy(self.config)
        get_fixture_recursive(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Inject logger, config, random state and run into captured functions.

        Afterwards, suspicious config changes are checked unless
        ``run.force`` is set.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if 'seed' in self.config:
            self.seed = self.config['seed']
            self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            # every captured function gets its own random state derived from
            # this scaffold's one
            seed = get_seed(self.rnd)
            cfunc.rnd = create_rnd(seed)
            cfunc.run = run

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Reject unused added entries and log other noteworthy changes.

        :raises KeyError: if an added config entry shares no prefix (as given
            by ``iter_prefixes``) with any captured argument.
        """
        for add in sorted(self.config_mods.added):
            if not set(iter_prefixes(add)).intersection(self.captured_args):
                raise KeyError('Added a new config entry "{}" that is not used'
                               ' anywhere'.format(add))
            self.logger.warning('Added new config entry: "%s"', add)

        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            # changes between int and float are not reported
            if type_old in (int, float) and type_new in (int, float):
                continue
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s',
                key, type_old.__name__, type_new.__name__)

        for cfg_summary in self.summaries:
            for key in cfg_summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    'ingredient.', key)

    def __repr__(self):
        return "<Scaffold: '{}'>".format(self.path)
192
|
|
|
|
193
|
|
|
|
194
|
|
|
def get_configuration(scaffolding):
    """Assemble the complete nested configuration from all scaffolds.

    Scaffolds are visited in reverse insertion order, and those with an empty
    config are skipped. The root scaffold (empty path) is merged into the top
    level; every other scaffold's config is placed under its dotted path.
    """
    merged = {}
    for path, scaff in reversed(list(scaffolding.items())):
        if not scaff.config:
            continue
        if not path:
            merged.update(scaff.config)
        else:
            set_by_dotted_path(merged, path, scaff.config)
    return merged
204
|
|
|
|
205
|
|
|
|
206
|
|
|
def distribute_named_configs(scaffolding, named_configs):
    """Hand each named config to the scaffold of the ingredient it belongs to.

    Entries that are existing file paths go to the root scaffold (``''``).
    All other entries are split at their last dot into
    ``<ingredient path>.<config name>``.

    NOTE(review): ``Scaffold`` in this file defines no ``use_named_config``
    method, and ``create_run`` does not call this function. With real
    Scaffold objects, a non-empty ``named_configs`` would raise
    AttributeError. This looks like legacy code; confirm before relying on it.

    :raises KeyError: if no scaffold exists for the ingredient path.
    """
    for entry in named_configs:
        if os.path.exists(entry):
            scaffolding[''].use_named_config(entry)
            continue
        ingredient_path, _, config_name = entry.rpartition('.')
        if ingredient_path not in scaffolding:
            raise KeyError('Ingredient for named config "{}" not found'
                           .format(entry))
        scaffolding[ingredient_path].use_named_config(config_name)
216
|
|
|
|
217
|
|
|
|
218
|
|
|
def initialize_logging(experiment, scaffolding):
    """Attach a logger to every scaffold and return the loggers for the run.

    The experiment's own logger serves as the root if it has one; otherwise a
    basic stream logger is created. Each scaffold with a non-empty path gets
    a child logger named after that path, and the root scaffold gets the root
    logger itself.

    :return: tuple ``(root_logger, run_logger)``, where ``run_logger`` is the
        root's child named after ``experiment.path``.
    """
    root_logger = experiment.logger
    if root_logger is None:
        root_logger = create_basic_stream_logger()

    for path, scaff in scaffolding.items():
        scaff.logger = root_logger.getChild(path) if path else root_logger

    return root_logger, root_logger.getChild(experiment.path)
231
|
|
|
|
232
|
|
|
|
233
|
|
|
def _create_scaffold(ingredient, scaffolding, path, generate_seed):
    # Build the Scaffold of one ingredient, linking the already created
    # scaffolds of its sub-ingredients (keyed by their path) as subrunners.
    return Scaffold(
        config_scopes=ingredient.configurations,
        subrunners=OrderedDict([(scaffolding[m].path, scaffolding[m])
                                for m in ingredient.ingredients]),
        path=path,
        captured_functions=ingredient.captured_functions,
        commands=ingredient.commands,
        named_configs=ingredient.named_configs,
        config_hooks=ingredient.config_hooks,
        generate_seed=generate_seed)


def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient, keyed by the ingredient's path.

    ``sorted_ingredients`` is assumed to list every sub-ingredient before its
    parent and the experiment last (as ``gather_ingredients_topological``
    orders them). That way each sub-ingredient's scaffold exists before it is
    linked as a subrunner.

    :return: OrderedDict mapping path -> Scaffold (``''`` for the experiment).
    :raises ValueError: if two ingredients share the same path.
    """
    scaffolding = OrderedDict()
    for ingredient in sorted_ingredients[:-1]:
        scaffolding[ingredient] = _create_scaffold(
            ingredient, scaffolding,
            path=ingredient.path if ingredient != experiment else '',
            generate_seed=False)

    # The experiment is the root: it always has the empty path and is the
    # only scaffold that writes its seed into the config.
    scaffolding[experiment] = _create_scaffold(
        experiment, scaffolding, path='', generate_seed=True)

    scaffolding_ret = OrderedDict([
        (sc.path, sc)
        for sc in scaffolding.values()
    ])
    # paths are the new keys, so duplicates would silently collapse
    if len(scaffolding_ret) != len(scaffolding):
        raise ValueError(
            'The pathes of the ingredients are not unique. '
            '{}'.format([s.path for s in scaffolding])
        )

    return scaffolding_ret
269
|
|
|
|
270
|
|
|
|
271
|
|
|
def gather_ingredients_topological(ingredient):
    """Return ``ingredient`` and all its sub-ingredients, deepest first.

    Each ingredient is ranked by the greatest depth at which
    ``traverse_ingredients`` reports it. Ties keep first-seen order.
    """
    max_depth = defaultdict(int)
    for sub_ing, depth in ingredient.traverse_ingredients():
        if depth > max_depth[sub_ing]:
            max_depth[sub_ing] = depth
    # sorted() stays stable with reverse=True, so ties keep insertion order
    return sorted(max_depth, key=max_depth.__getitem__, reverse=True)
276
|
|
|
|
277
|
|
|
|
278
|
|
|
def get_config_modifications(scaffolding):
    """Merge the config modifications of all scaffolds into one summary.

    Each scaffold's ``config_mods`` is added under that scaffold's path.
    """
    summary = ConfigSummary()
    for path in scaffolding:
        summary.update_add(scaffolding[path].config_mods, path=path)
    return summary
283
|
|
|
|
284
|
|
|
|
285
|
|
|
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path ``<ingredient path>.<name>``.

    A path without a dot refers to a command of the root scaffold (``''``).

    :raises KeyError: if the ingredient or the command does not exist.
    """
    ingredient_path, _, name = command_path.rpartition('.')
    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    commands = scaffolding[ingredient_path].commands
    if name not in commands:
        if ingredient_path:
            raise KeyError('Command "%s" not found in ingredient "%s"' %
                           (name, ingredient_path))
        raise KeyError('Command "%s" not found' % name)
    return commands[name]
298
|
|
|
|
299
|
|
|
|
300
|
|
|
def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    ``prefixes`` are ingredient paths already split on ``'.'``. They are
    expected to be ordered from deepest to shallowest, so the first match is
    the longest one.

    :return: ``(ingredient_path, remainder)``, or ``('', path)`` if no prefix
        matches.
    """
    parts = path.split('.')
    depth = len(parts)
    for prefix in prefixes:
        k = len(prefix)
        if k > depth or parts[:k] != prefix:
            continue
        return '.'.join(prefix), '.'.join(parts[k:])
    return '', path
307
|
|
|
|
308
|
|
|
|
309
|
|
|
def distribute_presets(prefixes, scaffolding, config_updates):
    """Route every flattened entry of ``config_updates`` into scaffold presets.

    The owner of an entry is the scaffold whose path is the longest prefix of
    the entry's dotted key (see ``find_best_match``). The value is stored in
    that scaffold's ``presets`` under the remaining relative key.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner_path, relative_key = find_best_match(dotted_key, prefixes)
        target = scaffolding[owner_path].presets
        set_by_dotted_path(target, relative_key, value)
314
|
|
|
|
315
|
|
|
|
316
|
|
|
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Route every flattened config update to the scaffold that owns it.

    The owner of an update is the scaffold whose path is the longest prefix
    of the update's dotted key (see ``find_best_match``). The value is stored
    in that scaffold's ``config_updates`` under the remaining relative key.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner_path, relative_key = find_best_match(dotted_key, prefixes)
        target = scaffolding[owner_path].config_updates
        set_by_dotted_path(target, relative_key, value)
321
|
|
|
|
322
|
|
|
|
323
|
|
|
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Resolve a named-config reference to ``(scaffold, config_name)``.

    An existing file path belongs to the root scaffold (``''``) and is
    returned unchanged as the config name. Any other reference is split at
    its last dot into ``<ingredient path>.<config name>``.

    :raises KeyError: if no scaffold exists for the ingredient path.
    """
    if os.path.exists(named_config):
        ingredient_path, config_name = '', named_config
    else:
        # rpartition gives (head, sep, tail); keep head and tail
        ingredient_path, config_name = named_config.rpartition('.')[::2]

    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for named config "{}" not found'
                       .format(named_config))
    return scaffolding[ingredient_path], config_name
334
|
|
|
|
335
|
|
|
|
336
|
|
|
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False):
    """Create a fully configured Run of ``command_name`` for ``experiment``.

    Configuration happens in four phases:

    1. distribute ``config_updates`` to the scaffolds of the ingredients,
    2. evaluate ``named_configs`` and add their values to updates and presets,
    3. evaluate the normal config scopes and run every scaffold's config hooks,
    4. set up the seeds.

    :param experiment: the experiment to run.
    :param command_name: dotted name of the command to execute.
    :param config_updates: optional dict of config updates; it is converted
        with ``convert_to_nested_dict``.
    :param named_configs: names of named configs, or paths of config files.
    :param force: if True, skip the check for suspicious config changes.
    :return: the initialized Run object.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # Every ingredient has its own config hooks, so they must run for
        # each scaffold, not only for the last one visited.
        # NOTE(review): the returned config_updates is not redistributed to
        # the scaffolds after this point; confirm whether hook results are
        # meant to reach the final config.
        config_updates = scaffold.run_config_hooks(config, config_updates,
                                                   command_name, run_logger)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
406
|
|
|
|