Completed
Push — master ( 5d8e11...d66198 )
by Klaus
03:33
created

create_scaffolding()   C

Complexity

Conditions 7

Size

Total Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 7
dl 0
loc 36
rs 5.5
c 1
b 0
f 0
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import os
6
from collections import OrderedDict, defaultdict
7
from copy import copy, deepcopy
8
9
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize,
10
                           load_config_file, undogmatize)
11
from sacred.config.config_summary import ConfigSummary
12
from sacred.host_info import get_host_info
13
from sacred.randomness import create_rnd, get_seed
14
from sacred.run import Run
15
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger,
16
                          get_by_dotted_path, is_prefix, rel_path,
17
                          iterate_flattened, set_by_dotted_path,
18
                          recursive_update, iter_prefixes, join_paths)
19
20
21
class Scaffold(object):
    """Configuration state of one ingredient while a run is being prepared.

    A scaffold bundles the config scopes, named configs, config hooks,
    commands and captured functions registered under one dotted ``path``
    together with the scaffolds of its sub-ingredients (``subrunners``), and
    derives the resolved config, the seed and the per-function wiring from
    them.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        """Store the collaborators and reset all derived state.

        ``subrunners`` is iterated as a mapping from dotted path to another
        scaffold; ``generate_seed`` decides whether ``set_up_seed`` writes
        the seed back into the config.
        """
        # Collaborators handed in by the caller.
        self.path = path
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        self.config_hooks = config_hooks
        self.commands = commands
        self.subrunners = subrunners
        self.generate_seed = generate_seed
        self._captured_functions = captured_functions

        # State that is filled in during the configuration process.
        self.config_updates = {}
        self.named_configs_to_use = []
        self.presets = {}
        self.config = {}
        self.summaries = []
        self.config_mods = None
        self.fallback = None
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None

        # Every dotted name that some captured function accepts as argument.
        known_args = set()
        for cfunc in self._captured_functions:
            for arg_name in cfunc.signature.arguments:
                known_args.add(join_paths(cfunc.prefix, arg_name))
        self.captured_args = known_args
        self.captured_args.add('__doc__')  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Fix the seed and random state of this scaffold, then of subrunners.

        Does nothing if a seed has already been set. Otherwise the seed is
        taken from the config entry ``'seed'`` or, if that is missing or
        ``None``, drawn via ``get_seed(rnd)``.
        """
        if self.seed is not None:
            return  # already seeded, e.g. through a parent scaffold

        seed = self.config.get('seed')
        if seed is None:
            seed = get_seed(rnd)
        self.seed = seed
        self.rnd = create_rnd(seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # A seed that was passed in as a config update counts as a
        # modification rather than an addition.
        if 'seed' in self.config:
            if 'seed' in self.config_mods.added:
                self.config_mods.modified.add('seed')
                self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        sub_items = list(self.subrunners.items())
        for sub_path, sub in reversed(sub_items):
            if is_prefix(self.path, sub_path):
                sub.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Collect the subrunner configs (and ``_log``) as read-only fallback."""
        collected = {'_log': self.logger}
        for sr_path, subrunner in self.subrunners.items():
            nested = self.path and is_prefix(self.path, sr_path)
            if nested:
                # store relative to this scaffold's own path
                target = sr_path[len(self.path):].strip('.')
            else:
                target = sr_path
            set_by_dotted_path(collected, target, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(collected)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate a named config and return its undogmatized result.

        ``config_name`` is treated as a config file if such a path exists,
        otherwise it is looked up in ``self.named_configs``.
        """
        is_file = os.path.exists(config_name)
        named_cfg = (ConfigDict(load_config_file(config_name)) if is_file
                     else self.named_configs[config_name])

        evaluated = named_cfg(fixed=self.get_config_updates_recursive(),
                              preset=self.presets,
                              fallback=self.fallback)
        return undogmatize(evaluated)

    def set_up_config(self):
        """Evaluate the config scopes and refresh the modification summary."""
        result = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)
        self.config, self.summaries = result

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Run all config hooks and merge their updates.

        Every hook receives its own deep copy of ``config``. The given
        ``config_updates`` are applied last, on top of the hook results.
        """
        merged = {}
        for hook in self.config_hooks:
            hook_updates = hook(deepcopy(config), command_name, logger)
            if not hook_updates:
                continue
            recursive_update(merged, hook_updates)
        recursive_update(merged, config_updates)
        return merged

    def get_config_modifications(self):
        """Rebuild ``self.config_mods`` from the updates and scope summaries."""
        added_keys = set()
        for key, _ in iterate_flattened(self.config_updates):
            added_keys.add(key)
        self.config_mods = ConfigSummary(added=added_keys)

        for summary in self.summaries:
            self.config_mods.update_from(summary)

    def get_config_updates_recursive(self):
        """Return a copy of the config updates including those of subrunners.

        Updates of a subrunner are stored under its path relative to this
        scaffold; subrunners without updates are left out.
        """
        combined = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            if is_prefix(self.path, sr_path):
                sub_updates = subrunner.get_config_updates_recursive()
                if sub_updates:
                    combined[rel_path(self.path, sr_path)] = sub_updates
        return combined

    def get_fixture(self):
        """Return the config of this scaffold merged with its subrunners'.

        The result is computed on first use and cached in ``self.fixture``.
        """
        if self.fixture is not None:
            return self.fixture

        def merge_subrunner_configs(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                subrunner.get_fixture()
                merge_subrunner_configs(subrunner)
                sub_config = copy(subrunner.config)
                if is_prefix(self.path, sr_path):
                    target = sr_path[len(self.path):].strip('.')
                else:
                    target = sr_path
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, target, sub_config)

        self.fixture = copy(self.config)
        merge_subrunner_configs(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Wire logger, config, random state and run into captured functions.

        Unless ``run.force`` is set, suspicious config changes are reported
        afterwards.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        self.seed = self.config.get('seed', self.seed)
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            # each captured function gets its own seeded random state
            cfunc.rnd = create_rnd(get_seed(self.rnd))
            cfunc.run = run

        if run.force:
            return
        self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Report added entries, type changes and ignored fallbacks.

        Raises:
            KeyError: if an added config entry shares none of the paths
                yielded by ``iter_prefixes`` with ``self.captured_args``.
        """
        for added in sorted(self.config_mods.added):
            used = set(iter_prefixes(added)).intersection(self.captured_args)
            if not used:
                raise KeyError(
                    'Added a new config entry "{}" that is not used anywhere'
                    .format(added))
            self.logger.warning('Added new config entry: "%s"' % added)

        numeric = (int, float)
        for key, types in self.config_mods.typechanged.items():
            type_old, type_new = types
            # switching between int and float is not worth a warning
            if not (type_old in numeric and type_new in numeric):
                self.logger.warning(
                    'Changed type of config entry "%s" from %s to %s' %
                    (key, type_old.__name__, type_new.__name__))

        for summary in self.summaries:
            for key in summary.ignored_fallbacks:
                message = ('Ignored attempt to set value of "%s", because it '
                           'is an ingredient.' % key)
                self.logger.warning(message)

    def __repr__(self):
        template = "<Scaffold: '{}'>"
        return template.format(self.path)
194
195
196
def get_configuration(scaffolding):
    """Merge the configs of all scaffolds into one nested dictionary.

    The scaffolds are visited in reverse order. A scaffold with an empty
    config is skipped, the config of the scaffold with the empty path is
    merged into the top level, and every other config is stored under the
    scaffold's dotted path.
    """
    merged = {}
    entries = list(scaffolding.items())
    entries.reverse()
    for sc_path, scaffold in entries:
        if scaffold.config:
            if not sc_path:
                merged.update(scaffold.config)
            else:
                set_by_dotted_path(merged, sc_path, scaffold.config)
    return merged
206
207
208
def distribute_named_configs(scaffolding, named_configs):
    """Hand every named config to the scaffold that is responsible for it.

    A name that is an existing filesystem path goes to the root scaffold
    (empty path). Any other name is split at its last dot into the
    scaffold path and the config name.

    NOTE(review): the ``Scaffold`` class in this file does not define a
    ``use_named_config`` method -- confirm whether this helper is still in
    use.

    Raises:
        KeyError: if the part before the last dot is not a scaffold path.
    """
    for ncfg in named_configs:
        if os.path.exists(ncfg):
            # config files are always handled by the root scaffold
            scaffolding[''].use_named_config(ncfg)
            continue

        path, _, cfg_name = ncfg.rpartition('.')
        if path not in scaffolding:
            message = 'Ingredient for named config "{}" not found'
            raise KeyError(message.format(ncfg))
        scaffolding[path].use_named_config(cfg_name)
218
219
220
def initialize_logging(experiment, scaffolding, log_level=None):
    """Attach loggers to all scaffolds and return the root and run logger.

    The experiment's own logger is used as root logger; if it is ``None`` a
    basic stream logger is created instead. The scaffold with the empty
    path receives the root logger, all others a child named after their
    path. If ``log_level`` is given it is applied to the root logger, as an
    int when it can be converted to one and unchanged otherwise.

    Returns:
        A tuple ``(root_logger, run_logger)`` where the run logger is the
        child of the root logger named ``experiment.path``.
    """
    root_logger = experiment.logger
    if root_logger is None:
        root_logger = create_basic_stream_logger()

    for sc_path, scaffold in scaffolding.items():
        scaffold.logger = (root_logger.getChild(sc_path) if sc_path
                           else root_logger)

    # set log level
    if log_level is not None:
        try:
            level = int(log_level)
        except ValueError:
            # not numeric: pass the value on unchanged
            level = log_level
        root_logger.setLevel(level)

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
241
242
243
def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient and return them keyed by path.

    All entries of ``sorted_ingredients`` except the last get a scaffold
    that does not generate a seed. The experiment gets the scaffold with
    the empty path, which does generate a seed. The subrunners of each
    scaffold are looked up among the scaffolds created before it.

    Raises:
        ValueError: if two scaffolds end up with the same path.
    """
    by_ingredient = OrderedDict()

    def subrunners_of(owner):
        # scaffolds of the owner's direct ingredients, keyed by their path
        subs = OrderedDict()
        for member in owner.ingredients:
            sub_scaffold = by_ingredient[member]
            subs[sub_scaffold.path] = sub_scaffold
        return subs

    for ingredient in sorted_ingredients[:-1]:
        by_ingredient[ingredient] = Scaffold(
            config_scopes=ingredient.configurations,
            subrunners=subrunners_of(ingredient),
            path=ingredient.path,
            captured_functions=ingredient.captured_functions,
            commands=ingredient.commands,
            named_configs=ingredient.named_configs,
            config_hooks=ingredient.config_hooks,
            generate_seed=False)

    by_ingredient[experiment] = Scaffold(
        experiment.configurations,
        subrunners=subrunners_of(experiment),
        path='',
        captured_functions=experiment.captured_functions,
        commands=experiment.commands,
        named_configs=experiment.named_configs,
        config_hooks=experiment.config_hooks,
        generate_seed=True)

    by_path = OrderedDict()
    for scaffold in by_ingredient.values():
        by_path[scaffold.path] = scaffold

    # fewer entries than before means that at least two paths collided
    if len(by_path) != len(by_ingredient):
        message = 'The pathes of the ingredients are not unique. {}'
        raise ValueError(message.format([s.path for s in by_ingredient]))

    return by_path
279
280
281
def gather_ingredients_topological(ingredient):
    """Return all traversed ingredients ordered from deepest to shallowest.

    ``ingredient.traverse_ingredients()`` is expected to yield
    ``(ingredient, depth)`` pairs. An ingredient that is reached several
    times is ranked by the largest depth it was seen at.
    """
    deepest = {}
    for sub_ing, depth in ingredient.traverse_ingredients():
        deepest[sub_ing] = max(deepest.get(sub_ing, 0), depth)
    # the sort is stable, so equally deep ingredients keep traversal order
    return sorted(deepest, key=deepest.get, reverse=True)
286
287
288
def get_config_modifications(scaffolding):
    """Combine the config modifications of all scaffolds into one summary.

    The ``config_mods`` of every scaffold are added to a fresh
    ``ConfigSummary`` together with the path of that scaffold.
    """
    combined = ConfigSummary()
    for sc_path in scaffolding:
        scaffold = scaffolding[sc_path]
        combined.update_add(scaffold.config_mods, path=sc_path)
    return combined
293
294
295
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    Everything before the last dot of ``command_path`` selects the
    scaffold, the rest is the name of the command within that scaffold.

    Raises:
        KeyError: if there is no scaffold for the path or the scaffold has
            no command with that name.
    """
    path, _, command_name = command_path.rpartition('.')
    if path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    commands = scaffolding[path].commands
    if command_name in commands:
        return commands[command_name]

    # only mention the ingredient if the command was not a top-level one
    if path:
        raise KeyError('Command "%s" not found in ingredient "%s"' %
                       (command_name, path))
    raise KeyError('Command "%s" not found' % command_name)
308
309
310
def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    Args:
        path: dotted path, e.g. ``'a.b.c'``.
        prefixes: iterable of ingredient paths, each given as a list of its
            dot-separated parts. The order of the prefixes does not matter.

    Returns:
        A tuple ``(ingredient_path, rest)`` of dotted strings: the longest
        prefix that matches the leading parts of ``path`` and the remainder
        of ``path`` after it. If no prefix matches, ``('', path)``.
    """
    path_parts = path.split('.')
    best = None
    for p in prefixes:
        if len(p) <= len(path_parts) and p == path_parts[:len(p)]:
            # keep the first of the longest matches, so that callers that
            # pass the prefixes sorted deepest-first get the same result
            if best is None or len(p) > len(best):
                best = p
    if best is None:
        return '', path
    return '.'.join(best), '.'.join(path_parts[len(best):])
317
318
319
def distribute_presets(prefixes, scaffolding, config_updates):
    """Store every flattened entry in the presets of its best-match scaffold.

    The scaffold is chosen with ``find_best_match``; the entry is stored
    under the part of its path that remains after the scaffold's prefix.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner, remainder = find_best_match(dotted_key, prefixes)
        set_by_dotted_path(scaffolding[owner].presets, remainder, value)
324
325
326
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Store every flattened update in the scaffold it belongs to.

    For each flattened ``(path, value)`` pair the scaffold is chosen with
    ``find_best_match`` and the value is written into that scaffold's
    ``config_updates`` under the rest of the path.
    """
    for flat_entry in iterate_flattened(config_updates):
        full_path, value = flat_entry
        scaffold_name, local_path = find_best_match(full_path, prefixes)
        target = scaffolding[scaffold_name].config_updates
        set_by_dotted_path(target, local_path, value)
331
332
333
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Resolve a named config to its scaffold and local config name.

    A name that is an existing filesystem path belongs to the root scaffold
    (empty path) and is returned unchanged. Any other name is split at its
    last dot into the scaffold path and the config name.

    Returns:
        A tuple ``(scaffold, config_name)``.

    Raises:
        KeyError: if the part before the last dot is not a scaffold path.
    """
    if os.path.exists(named_config):
        # config files are always handled by the root scaffold
        return scaffolding[''], named_config

    path, _, cfg_name = named_config.rpartition('.')
    if path not in scaffolding:
        message = 'Ingredient for named config "{}" not found'
        raise KeyError(message.format(named_config))
    return scaffolding[path], cfg_name
344
345
346
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    """Build a fully configured ``Run`` for a command of the experiment.

    The configuration process has four phases: distribute the config
    updates, evaluate the named configs, evaluate the normal config scopes
    (including config hooks) and finally set up the seeds. Afterwards the
    ``Run`` object is created and every scaffold finalizes its captured
    functions with it.

    Args:
        experiment: the experiment to create the run for.
        command_name: dotted path of the command to run.
        config_updates: optional config updates; ``None`` means no updates.
        named_configs: names (or file paths) of named configs to apply.
        force: stored as ``run.force``; when true the scaffolds skip the
            check for suspicious config changes.
        log_level: optional level for the root logger.

    Returns:
        The initialized ``Run`` object.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = [sc_path.split('.') for sc_path in scaffolding
                if sc_path != '']
    prefixes.sort(key=len, reverse=True)

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = convert_to_nested_dict(config_updates or {})
    loggers = initialize_logging(experiment, scaffolding, log_level)
    root_logger, run_logger = loggers
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for named_config in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(named_config,
                                                          scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # the results of a named config are also treated as config updates
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            full_key = join_paths(scaff.path, ncfg_key)
            set_by_dotted_path(config_updates, full_key, value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        # NOTE(review): the updates returned here are passed on to the hooks
        # of the next scaffold, but the value left after the last iteration
        # is not read again in this function -- confirm this is intended.
        config_updates = scaffold.run_config_hooks(config, config_updates,
                                                   command_name, run_logger)

    # Phase 4: finalize seeding
    all_scaffolds = list(scaffolding.values())
    for scaffold in reversed(all_scaffolds):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)

    pre_runs = []
    for ingredient in sorted_ingredients:
        pre_runs.extend(ingredient.pre_run_hooks)
    post_runs = []
    for ingredient in sorted_ingredients:
        post_runs.extend(ingredient.post_run_hooks)

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
417