|
1
|
|
|
#!/usr/bin/env python |
|
2
|
|
|
# coding=utf-8 |
|
3
|
|
|
from __future__ import division, print_function, unicode_literals |
|
4
|
|
|
|
|
5
|
|
|
import os |
|
6
|
|
|
from collections import OrderedDict, defaultdict |
|
7
|
|
|
from copy import copy, deepcopy |
|
8
|
|
|
|
|
9
|
|
|
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize, |
|
10
|
|
|
load_config_file, undogmatize) |
|
11
|
|
|
from sacred.config.config_summary import ConfigSummary |
|
12
|
|
|
from sacred.host_info import get_host_info |
|
13
|
|
|
from sacred.randomness import create_rnd, get_seed |
|
14
|
|
|
from sacred.run import Run |
|
15
|
|
|
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger, |
|
16
|
|
|
get_by_dotted_path, is_prefix, |
|
17
|
|
|
iterate_flattened, set_by_dotted_path, |
|
18
|
|
|
recursive_update, iter_prefixes, join_paths) |
|
19
|
|
|
|
|
20
|
|
|
|
|
21
|
|
|
class Scaffold(object):
    """Per-ingredient helper that assembles config, seed and fixture.

    One Scaffold exists for every ingredient taking part in a run (and one
    for the experiment itself, with path ``''``). ``create_run`` drives the
    scaffolds through the configuration phases; ``finalize_initialization``
    then hands the resulting configuration to the captured functions.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        # mapping of path -> Scaffold for the directly used ingredients
        self.subrunners = subrunners
        self.path = path
        # if true, the chosen seed is written back into self.config['seed']
        self.generate_seed = generate_seed
        self.config_hooks = config_hooks
        # path-relative values that are fixed for this scaffold
        self.config_updates = {}
        # named configs queued via use_named_config
        self.named_configs_to_use = []
        self.config = {}
        self.fallback = None
        # path-relative presets (filled from named configs)
        self.presets = {}
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None
        self._captured_functions = captured_functions
        self.commands = commands
        self.config_mods = None
        self.summaries = []
        # dotted names of all captured-function arguments; an added config
        # entry is only accepted if one of its prefixes is in this set
        self.captured_args = {join_paths(cf.prefix, n)
                              for cf in self._captured_functions
                              for n in cf.signature.arguments}
        self.captured_args.add('__doc__')  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Choose this scaffold's seed and random state, then recurse.

        The seed is taken from ``config['seed']`` if present, otherwise it is
        drawn via ``get_seed(rnd)``. A scaffold that already has a seed is
        left untouched, so repeated calls are no-ops.

        :param rnd: random state passed to ``get_seed`` when the config does
            not specify a seed.
        """
        if self.seed is not None:
            return

        self.seed = self.config.get('seed')
        if self.seed is None:
            self.seed = get_seed(rnd)

        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # Report 'seed' as modified rather than added, so it does not trip
        # the "added config entry" check. config_mods only exists once
        # set_up_config has run, hence the None guard.
        if ('seed' in self.config and self.config_mods is not None and
                'seed' in self.config_mods.added):
            self.config_mods.modified.add('seed')
            self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        for subrunner_path, subrunner in reversed(list(
                self.subrunners.items())):
            if is_prefix(self.path, subrunner_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Collect the subrunners' configs into ``self.fallback``.

        Subrunners whose path lies below this scaffold's (non-empty) path are
        stored relative to it; all others are stored at their full path.
        """
        fallback = {}
        for sr_path, subrunner in self.subrunners.items():
            if self.path and is_prefix(self.path, sr_path):
                path = sr_path[len(self.path):].strip('.')
                set_by_dotted_path(fallback, path, subrunner.config)
            else:
                set_by_dotted_path(fallback, sr_path, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(fallback)
        self.fallback.revelation()

    def _resolve_named_config(self, config_name):
        """Return the config object referred to by ``config_name``.

        An existing file path is loaded into a ``ConfigDict``; any other name
        is looked up in ``self.named_configs`` (``KeyError`` if unknown).
        """
        if os.path.exists(config_name):
            return ConfigDict(load_config_file(config_name))
        return self.named_configs[config_name]

    def use_named_config(self, config_name):
        """Queue the named config (or config file) ``config_name``.

        The name is resolved like in :meth:`run_named_config` and the result
        is appended to ``self.named_configs_to_use``.
        """
        self.named_configs_to_use.append(
            self._resolve_named_config(config_name))

    def run_named_config(self, config_name):
        """Evaluate a named config (or config file) and return its values.

        All recursive config updates are fixed, ``self.presets`` serve as
        presets, and ``self.fallback`` provides the subrunner configs.

        :return: the evaluated configuration, undogmatized.
        """
        nc = self._resolve_named_config(config_name)

        cfg = nc(fixed=self.get_config_updates_recursive(),
                 preset=self.presets,
                 fallback=self.fallback)

        return undogmatize(cfg)

    def set_up_config(self):
        """Evaluate this scaffold's config scopes into ``self.config``.

        Uses ``self.config_updates`` as fixed values, the current config as
        preset and ``self.fallback`` as fallback. Afterwards
        ``self.config_mods`` is recomputed.
        """
        self.config, self.summaries = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Run all config hooks and merge what they return.

        Each hook receives a deep copy of ``config``, the command name and
        the logger. Falsy results are ignored. ``config_updates`` is merged
        last, so its values take precedence over any hook result.

        :return: a new dict with the merged updates.
        """
        final_cfg_updates = {}
        for ch in self.config_hooks:
            cfg_upup = ch(deepcopy(config), command_name, logger)
            if cfg_upup:
                recursive_update(final_cfg_updates, cfg_upup)
        recursive_update(final_cfg_updates, config_updates)
        return final_cfg_updates

    def get_config_modifications(self):
        """Recompute ``self.config_mods``.

        Every flattened key of ``self.config_updates`` counts as added; the
        summaries of the evaluated config scopes are then merged in.
        """
        self.config_mods = ConfigSummary(
            added={key
                   for key, value in iterate_flattened(self.config_updates)})
        for cfg_summary in self.summaries:
            self.config_mods.update_from(cfg_summary)

    def get_config_updates_recursive(self):
        """Return the config updates of this scaffold and all subrunners.

        The result is a shallow copy of ``self.config_updates``. Each
        subrunner's (recursive) updates are stored under its path as key.
        """
        config_updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            config_updates[sr_path] = subrunner.get_config_updates_recursive()
        return config_updates

    def get_fixture(self):
        """Return (and cache) the config merged with all subrunner configs.

        The fixture starts as a shallow copy of ``self.config``. The config
        of every transitively reachable subrunner is then inserted at its
        path, relative to ``self.path`` where that path is a prefix.
        """
        if self.fixture is not None:
            return self.fixture

        def get_fixture_recursive(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # NOTE(review): original author was unsure whether computing
                # every subrunner's fixture here is necessary.
                subrunner.get_fixture()
                get_fixture_recursive(subrunner)
                sub_fix = copy(subrunner.config)
                sub_path = sr_path
                if is_prefix(self.path, sub_path):
                    sub_path = sr_path[len(self.path):].strip('.')
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, sub_path, sub_fix)

        self.fixture = copy(self.config)
        get_fixture_recursive(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Wire the final configuration into the captured functions.

        Each captured function gets:

        * a child logger;
        * its slice of the fixture (at the function's prefix);
        * its own random state seeded from this scaffold's ``rnd``;
        * a reference to ``run``.

        Unless ``run.force`` is set, suspicious config changes are then
        checked.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if 'seed' in self.config:
            self.seed = self.config['seed']
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            seed = get_seed(self.rnd)
            cfunc.rnd = create_rnd(seed)
            cfunc.run = run

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Complain about questionable config modifications.

        The following are checked:

        * Added entries: raises ``KeyError`` if no prefix of the entry is a
          captured argument; otherwise a warning is logged.
        * Type changes: a warning is logged, except for int <-> float.
        * Ignored fallbacks: a warning is logged for each one.
        """
        for add in sorted(self.config_mods.added):
            if not set(iter_prefixes(add)).intersection(self.captured_args):
                raise KeyError('Added a new config entry "{}" that is not used'
                               ' anywhere'.format(add))
            else:
                self.logger.warning('Added new config entry: "%s"', add)

        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            if type_old in (int, float) and type_new in (int, float):
                continue
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s',
                key, type_old.__name__, type_new.__name__)

        for cfg_summary in self.summaries:
            for key in cfg_summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    'ingredient.', key
                )

    def __repr__(self):
        return "<Scaffold: '{}'>".format(self.path)
|
190
|
|
|
|
|
191
|
|
|
|
|
192
|
|
|
def get_configuration(scaffolding):
    """Assemble the global configuration from all scaffolds.

    Scaffolds are visited in reverse order. The scaffold with the empty path
    is merged in at the top level; every other scaffold's config is inserted
    at its dotted path. Scaffolds with an empty config are skipped.
    """
    merged = {}
    for path, scaff in reversed(list(scaffolding.items())):
        scaff_config = scaff.config
        if not scaff_config:
            continue
        if not path:
            merged.update(scaff_config)
        else:
            set_by_dotted_path(merged, path, scaff_config)
    return merged
|
202
|
|
|
|
|
203
|
|
|
|
|
204
|
|
|
def distribute_named_configs(scaffolding, named_configs):
    """Queue every named config on the scaffold it belongs to.

    An existing file path is always queued on the root scaffold (``''``).
    Any other name is split at its last dot into ingredient path and config
    name.

    NOTE(review): ``create_run`` resolves named configs itself via
    ``get_scaffolding_and_config_name``; this helper may be unused — verify.

    :raises KeyError: if the ingredient of a named config does not exist.
    """
    for name in named_configs:
        if os.path.exists(name):
            scaffolding[''].use_named_config(name)
            continue

        ingredient_path, _, config_name = name.rpartition('.')
        if ingredient_path not in scaffolding:
            raise KeyError('Ingredient for named config "{}" not found'
                           .format(name))
        scaffolding[ingredient_path].use_named_config(config_name)
|
214
|
|
|
|
|
215
|
|
|
|
|
216
|
|
|
def initialize_logging(experiment, scaffolding, log_level=None):
    """Attach loggers to all scaffolds.

    The experiment's own logger is used as root logger if it has one;
    otherwise a basic stream logger is created. The scaffold with the empty
    path receives the root logger itself; every other scaffold receives a
    child logger named after its path.

    :param log_level: optional level for the root logger. Strings that parse
        as ``int`` are converted; anything else is passed to ``setLevel``
        unchanged (e.g. ``'INFO'``).
    :return: ``(root_logger, run_logger)``, where ``run_logger`` is the
        root's child named after ``experiment.path``.
    """
    root_logger = (create_basic_stream_logger() if experiment.logger is None
                   else experiment.logger)

    for sc_path, scaffold in scaffolding.items():
        scaffold.logger = (root_logger.getChild(sc_path) if sc_path
                           else root_logger)

    # set log level
    if log_level is not None:
        try:
            level = int(log_level)
        except ValueError:
            level = log_level
        root_logger.setLevel(level)

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
|
237
|
|
|
|
|
238
|
|
|
|
|
239
|
|
|
def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient, keyed by the scaffold's path.

    All entries of ``sorted_ingredients`` except the last are built as
    ingredient scaffolds, in order. Each subrunner must already have a
    scaffold when its user is built, which ``gather_ingredients_topological``
    ordering provides. The experiment's scaffold is built afterwards. It
    always has path ``''`` and is the only one that generates a seed.

    :raises ValueError: if two scaffolds end up with the same path.
    """
    by_ingredient = OrderedDict()

    def build(ing, path, generate_seed):
        # subrunner scaffolds are looked up among those built so far
        subrunners = OrderedDict(
            (by_ingredient[sub].path, by_ingredient[sub])
            for sub in ing.ingredients)
        return Scaffold(config_scopes=ing.configurations,
                        subrunners=subrunners,
                        path=path,
                        captured_functions=ing.captured_functions,
                        commands=ing.commands,
                        named_configs=ing.named_configs,
                        config_hooks=ing.config_hooks,
                        generate_seed=generate_seed)

    for ing in sorted_ingredients[:-1]:
        by_ingredient[ing] = build(
            ing,
            ing.path if ing != experiment else '',
            generate_seed=False)

    by_ingredient[experiment] = build(experiment, '', generate_seed=True)

    by_path = OrderedDict((scaff.path, scaff)
                          for scaff in by_ingredient.values())
    if len(by_path) != len(by_ingredient):
        raise ValueError(
            'The pathes of the ingredients are not unique. '
            '{}'.format([ing.path for ing in by_ingredient])
        )

    return by_path
|
275
|
|
|
|
|
276
|
|
|
|
|
277
|
|
|
def gather_ingredients_topological(ingredient):
    """Return all ingredients reachable from ``ingredient``, deepest first.

    Each ingredient is ranked by the largest depth at which
    ``traverse_ingredients`` reports it. The result is sorted by that depth
    in descending order. Ties keep their first-seen order, since sorting is
    stable even with ``reverse=True``.
    """
    max_depth = defaultdict(int)
    for sub_ing, depth in ingredient.traverse_ingredients():
        if depth > max_depth[sub_ing]:
            max_depth[sub_ing] = depth
    return sorted(max_depth, key=max_depth.__getitem__, reverse=True)
|
282
|
|
|
|
|
283
|
|
|
|
|
284
|
|
|
def get_config_modifications(scaffolding):
    """Combine the ``config_mods`` of all scaffolds into one ConfigSummary.

    Each scaffold's summary is added under the scaffold's path.
    """
    combined = ConfigSummary()
    for path in scaffolding:
        combined.update_add(scaffolding[path].config_mods, path=path)
    return combined
|
289
|
|
|
|
|
290
|
|
|
|
|
291
|
|
|
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    The part before the last dot selects the ingredient (``''`` for the
    experiment) and the part after it names the command.

    :raises KeyError: if the ingredient or the command does not exist.
    """
    ingredient_path, _, command_name = command_path.rpartition('.')
    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    commands = scaffolding[ingredient_path].commands
    if command_name in commands:
        return commands[command_name]

    if ingredient_path:
        raise KeyError('Command "%s" not found in ingredient "%s"' %
                       (command_name, ingredient_path))
    raise KeyError('Command "%s" not found' % command_name)
|
304
|
|
|
|
|
305
|
|
|
|
|
306
|
|
|
def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    ``prefixes`` holds already-split paths (lists of parts). It is scanned
    in order, so it must be sorted deepest first for the first hit to be
    the longest one.

    :return: ``(prefix, remainder)`` as dotted strings, or ``('', path)``
        if no prefix matches.
    """
    parts = path.split('.')
    for candidate in prefixes:
        depth = len(candidate)
        if depth > len(parts) or parts[:depth] != candidate:
            continue
        return '.'.join(candidate), '.'.join(parts[depth:])
    return '', path
|
313
|
|
|
|
|
314
|
|
|
|
|
315
|
|
|
def distribute_presets(prefixes, scaffolding, config_updates):
    """Store every flattened entry of ``config_updates`` as a preset.

    Each entry goes to the scaffold selected by ``find_best_match``, under
    the key's remainder relative to that scaffold's path.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner, rest = find_best_match(dotted_key, prefixes)
        set_by_dotted_path(scaffolding[owner].presets, rest, value)
|
320
|
|
|
|
|
321
|
|
|
|
|
322
|
|
|
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Route every flattened entry of ``config_updates`` to its scaffold.

    The owning scaffold is chosen by ``find_best_match``. The value is
    stored in that scaffold's ``config_updates`` under the key's remainder
    relative to the scaffold's path.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner, rest = find_best_match(dotted_key, prefixes)
        target = scaffolding[owner].config_updates
        set_by_dotted_path(target, rest, value)
|
327
|
|
|
|
|
328
|
|
|
|
|
329
|
|
|
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Split ``named_config`` into its owning scaffold and the config name.

    An existing file path belongs to the root scaffold (``''``) and is
    returned unchanged as the name. Otherwise the part before the last dot
    selects the ingredient and the part after it is the config name.

    :return: ``(scaffold, config_name)``.
    :raises KeyError: if no scaffold exists for the derived ingredient path.
    """
    if os.path.exists(named_config):
        ingredient_path, config_name = '', named_config
    else:
        ingredient_path, config_name = named_config.rpartition('.')[::2]

    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for named config "{}" not found'
                       .format(named_config))
    return scaffolding[ingredient_path], config_name
|
340
|
|
|
|
|
341
|
|
|
|
|
342
|
|
|
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    """Configure all ingredients of ``experiment`` and build a ``Run``.

    Configuration happens in four phases:

    1. explicit ``config_updates`` are routed to the scaffold owning each
       (dotted) key;
    2. ``named_configs`` are evaluated in order and their values become
       presets as well as additional fixed config updates;
    3. every scaffold evaluates its config scopes, then its config hooks,
       whose results are applied to that scaffold's config;
    4. seeds are set up, starting from the last scaffold.

    :param experiment: the experiment to run.
    :param command_name: dotted name of the command, resolved by
        ``get_command``.
    :param config_updates: optional (possibly dotted) mapping of fixed
        config values.
    :param named_configs: names of named configs or paths of config files.
    :param force: stored as ``run.force``; if true, the check for
        suspicious config changes is skipped.
    :param log_level: optional root logger level (int, numeric string or
        level name).
    :return: the initialized Run.
    :raises KeyError: for unknown ingredients, named configs or commands.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=len)

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # record named-config values as global (full-path) config updates so
        # they are fixed for all subsequent evaluation
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # Run the config hooks and actually apply their result to this
        # scaffold's config. run_config_hooks merges the updates passed in
        # last, so the scaffold's own (path-relative) fixed values still win
        # over hook results. The global full-path ``config_updates`` must not
        # be passed here: it would be merged into an ingredient's local config
        # at the wrong nesting level.
        hook_updates = scaffold.run_config_hooks(
            config, scaffold.config_updates, command_name, run_logger)
        recursive_update(scaffold.config, hook_updates)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
|
413
|
|
|
|