Completed
Push — master ( 28fb1c...6c3661 )
by Klaus
34s
created

Scaffold.get_config_updates_recursive()   A

Complexity

Conditions 2

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric | Value
cc     | 2
c      | 0
b      | 0
f      | 0
dl     | 0
loc    | 5
rs     | 9.4285
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import os
6
from collections import OrderedDict, defaultdict
7
from copy import copy, deepcopy
8
9
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize,
10
                           load_config_file, undogmatize)
11
from sacred.config.config_summary import ConfigSummary
12
from sacred.host_info import get_host_info
13
from sacred.randomness import create_rnd, get_seed
14
from sacred.run import Run
15
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger,
16
                          get_by_dotted_path, is_prefix,
17
                          iterate_flattened, set_by_dotted_path,
18
                          recursive_update, iter_prefixes, join_paths)
19
20
# Module-level marker flag; nothing in this module reads it. Per the original
# note below, it is consumed elsewhere to filter this file out of stack traces.
__sacred__ = True  # marks files that should be filtered from stack traces
21
22
23
class Scaffold(object):
    """Working state for one ingredient while a Run is being assembled.

    ``create_scaffolding`` builds one Scaffold per ingredient plus one for the
    experiment itself (whose path is ``''``). ``create_run`` drives the
    scaffolds through the configuration phases (config updates, named
    configs, config scopes and hooks, seeding). ``finalize_initialization``
    then injects the resulting config, logger, random state and run into the
    ingredient's captured functions.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        """Store the ingredient's definitions and reset the per-run state.

        Args:
            config_scopes: Config scopes of the ingredient, evaluated by
                ``set_up_config``.
            subrunners: Mapping ``path -> Scaffold`` of the direct
                sub-ingredients (``create_scaffolding`` passes an
                OrderedDict).
            path: Dotted path of the ingredient; ``''`` for the experiment.
            captured_functions: Captured functions of the ingredient.
            commands: Mapping of command name to command.
            named_configs: Mapping of named-config name to named config.
            config_hooks: Callables ``hook(config, command_name, logger)``
                whose (truthy) return values are merged as config updates.
            generate_seed: If true, the resolved seed is written back into
                ``self.config['seed']``.
        """
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        self.subrunners = subrunners
        self.path = path
        self.generate_seed = generate_seed
        self.config_hooks = config_hooks
        self.config_updates = {}
        self.named_configs_to_use = []
        self.config = {}
        self.fallback = None
        self.presets = {}
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None
        self._captured_functions = captured_functions
        self.commands = commands
        self.config_mods = None
        self.summaries = []
        # Dotted paths of every argument of every captured function; used to
        # decide whether a newly added config entry is actually consumed.
        self.captured_args = {join_paths(cf.prefix, n)
                              for cf in self._captured_functions
                              for n in cf.signature.arguments}
        self.captured_args.add('__doc__')  # allow setting the config docstring

    def set_up_seed(self, rnd=None):
        """Resolve the seed and random state of this scaffold and of the
        subrunners nested below it.

        Does nothing if a seed has already been set. The seed comes from
        ``self.config['seed']`` if present, otherwise from ``get_seed(rnd)``.

        Args:
            rnd: Random state of the parent scaffold, passed down when this
                method recurses into subrunners; ``None`` for the top call.
        """
        if self.seed is not None:
            return

        self.seed = self.config.get('seed')
        if self.seed is None:
            self.seed = get_seed(rnd)

        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # A seed present in the config but listed as "added" is reported as
        # modified instead, so _warn_about_suspicious_changes does not treat
        # it as an unused new entry.
        if 'seed' in self.config and 'seed' in self.config_mods.added:
            self.config_mods.modified.add('seed')
            self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        for subrunner_path, subrunner in reversed(list(
                self.subrunners.items())):
            if is_prefix(self.path, subrunner_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Collect the configs of all subrunners into a read-only fallback.

        Subrunners nested below this (non-root) scaffold's path are stored at
        their path relative to it; all others at their full path.
        """
        fallback = {}
        for sr_path, subrunner in self.subrunners.items():
            if self.path and is_prefix(self.path, sr_path):
                path = sr_path[len(self.path):].strip('.')
                set_by_dotted_path(fallback, path, subrunner.config)
            else:
                set_by_dotted_path(fallback, sr_path, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(fallback)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate a named config of this ingredient and return its values.

        ``config_name`` is either the path of an existing config file or a
        key of ``self.named_configs``. The config updates of this scaffold
        and its nested subrunners are passed as fixed values,
        ``self.presets`` as presets and ``self.fallback`` as fallback.

        Returns:
            The resulting config, undogmatized.
        """
        if os.path.exists(config_name):
            nc = ConfigDict(load_config_file(config_name))
        else:
            nc = self.named_configs[config_name]

        cfg = nc(fixed=self.get_config_updates_recursive(),
                 preset=self.presets,
                 fallback=self.fallback)

        return undogmatize(cfg)

    def set_up_config(self):
        """Evaluate the config scopes and rebuild ``self.config_mods``.

        Sets ``self.config`` and ``self.summaries`` from
        ``chain_evaluate_config_scopes``.
        """
        self.config, self.summaries = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Run all config hooks and merge their results with config_updates.

        Each hook receives a deep copy of ``config`` so it cannot mutate the
        caller's dict. ``config_updates`` is merged in after all hook results.

        Returns:
            A new dict holding the merged updates.
        """
        final_cfg_updates = {}
        for ch in self.config_hooks:
            cfg_upup = ch(deepcopy(config), command_name, logger)
            if cfg_upup:
                recursive_update(final_cfg_updates, cfg_upup)
        recursive_update(final_cfg_updates, config_updates)
        return final_cfg_updates

    def get_config_modifications(self):
        """Rebuild ``self.config_mods`` from the updates and scope summaries.

        Every flattened key of ``self.config_updates`` starts out as "added";
        the summaries of the evaluated config scopes are then folded in.
        """
        self.config_mods = ConfigSummary(
            added={key for key, _ in iterate_flattened(self.config_updates)})
        for cfg_summary in self.summaries:
            self.config_mods.update_from(cfg_summary)

    def get_config_updates_recursive(self):
        """Return this scaffold's config updates plus those of its subrunners.

        Only subrunners nested below this scaffold's path are included (the
        root scaffold, path ``''``, includes all of them). Each one is keyed
        by its path relative to this scaffold, matching ``gather_fallbacks``.
        Subrunners without any updates are left out.

        Returns:
            A new dict; ``self.config_updates`` is not modified.
        """
        config_updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            # Subrunners outside this scaffold's path cannot be addressed
            # from within its (named) configs.
            if self.path and not is_prefix(self.path, sr_path):
                continue
            update = subrunner.get_config_updates_recursive()
            # An empty dict would become a fixed entry of the named config.
            # create_run writes named-config values back into config_updates,
            # so it could replace updates already stored there.
            if update:
                rel_path = sr_path[len(self.path):].strip('.')
                config_updates[rel_path] = update
        return config_updates

    def get_fixture(self):
        """Return (and cache) this scaffold's config with subrunner configs
        embedded.

        Shallow copies of the configs of all subrunners (recursively) are
        stored in the fixture. Each goes at its path relative to this
        scaffold when nested below it, otherwise at its full path.
        """
        if self.fixture is not None:
            return self.fixture

        def get_fixture_recursive(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # NOTE: unclear whether triggering get_fixture() on every
                # subrunner is actually necessary.
                subrunner.get_fixture()
                get_fixture_recursive(subrunner)
                sub_fix = copy(subrunner.config)
                sub_path = sr_path
                if is_prefix(self.path, sub_path):
                    sub_path = sr_path[len(self.path):].strip('.')
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, sub_path, sub_fix)

        self.fixture = copy(self.config)
        get_fixture_recursive(self)

        return self.fixture

    def finalize_initialization(self, run):
        """Wire config, logger, random state and ``run`` into the captured
        functions.

        Afterwards, suspicious config changes are checked unless
        ``run.force`` is set.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if 'seed' in self.config:
            self.seed = self.config['seed']
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            seed = get_seed(self.rnd)
            cfunc.rnd = create_rnd(seed)
            cfunc.run = run

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Raise or warn about config changes that look like mistakes.

        Type changes between ``int`` and ``float`` are not reported.

        Raises:
            KeyError: If an added config entry has no prefix (as yielded by
                ``iter_prefixes``) among the captured-function arguments.
        """
        for add in sorted(self.config_mods.added):
            if not set(iter_prefixes(add)).intersection(self.captured_args):
                raise KeyError('Added a new config entry "{}" that is not used'
                               ' anywhere'.format(add))
            else:
                self.logger.warning('Added new config entry: "%s"', add)

        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            if type_old in (int, float) and type_new in (int, float):
                continue
            self.logger.warning(
                'Changed type of config entry "%s" from %s to %s',
                key, type_old.__name__, type_new.__name__)

        for cfg_summary in self.summaries:
            for key in cfg_summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    'ingredient.', key
                )

    def __repr__(self):
        return "<Scaffold: '{}'>".format(self.path)
192
193
194
def get_configuration(scaffolding):
    """Merge the configs of all scaffolds into one nested dict.

    Scaffolds are visited in reverse order, which puts the root first when
    the mapping comes from ``create_scaffolding``. Empty configs are
    skipped. The root config (path ``''``) is merged into the top level;
    every other config is stored at its scaffold's dotted path.
    """
    merged = {}
    entries = list(scaffolding.items())
    entries.reverse()
    for path, scaff in entries:
        cfg = scaff.config
        if cfg:
            if path:
                set_by_dotted_path(merged, path, cfg)
            else:
                merged.update(cfg)
    return merged
204
205
206
def distribute_named_configs(scaffolding, named_configs):
    """Record each named config on the scaffold it belongs to.

    Each entry of ``named_configs`` is either the path of an existing config
    file, which is assigned to the root scaffold ``''``, or a dotted name
    ``'<ingredient path>.<config name>'``. A bare name addresses the root.
    The config name (or file path) is appended to the owning scaffold's
    ``named_configs_to_use`` list.

    NOTE(review): ``Scaffold`` defines no ``use_named_config`` method, so
    delegating to it always raised AttributeError; the name is now recorded
    directly. Nothing in this module reads ``named_configs_to_use``
    (``create_run`` resolves named configs itself). Confirm whether this
    helper is still needed.

    Raises:
        KeyError: If no scaffold exists for the ingredient path.
    """
    for ncfg in named_configs:
        if os.path.exists(ncfg):
            path, cfg_name = '', ncfg
        else:
            path, _, cfg_name = ncfg.rpartition('.')
            if path not in scaffolding:
                raise KeyError('Ingredient for named config "{}" not found'
                               .format(ncfg))
        scaffolding[path].named_configs_to_use.append(cfg_name)
216
217
218
def initialize_logging(experiment, scaffolding):
    """Attach loggers to all scaffolds and return the run's loggers.

    The experiment's own logger is used as root when set; otherwise a basic
    stream logger is created. The root scaffold (path ``''``) gets the root
    logger itself; every other scaffold gets a child named after its path.

    Returns:
        ``(root_logger, run_logger)`` where ``run_logger`` is the root's
        child named after ``experiment.path``.
    """
    root_logger = experiment.logger
    if root_logger is None:
        root_logger = create_basic_stream_logger()

    for path, scaff in scaffolding.items():
        scaff.logger = root_logger.getChild(path) if path else root_logger

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
231
232
233
def _create_scaffold(ingredient, scaffolding, path, generate_seed):
    """Build the Scaffold of ``ingredient`` at ``path``.

    The scaffolds of its direct sub-ingredients must already be present in
    ``scaffolding`` (keyed by ingredient); they become its subrunners.
    """
    return Scaffold(
        config_scopes=ingredient.configurations,
        subrunners=OrderedDict([(scaffolding[m].path, scaffolding[m])
                                for m in ingredient.ingredients]),
        path=path,
        captured_functions=ingredient.captured_functions,
        commands=ingredient.commands,
        named_configs=ingredient.named_configs,
        config_hooks=ingredient.config_hooks,
        generate_seed=generate_seed)


def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient, keyed by scaffold path.

    Args:
        experiment: The experiment. Its scaffold gets path ``''`` and is the
            only one created with ``generate_seed=True``.
        sorted_ingredients: All ingredients ordered so that every ingredient
            comes after its sub-ingredients, with ``experiment`` last (as
            produced by ``gather_ingredients_topological``).

    Returns:
        OrderedDict mapping scaffold path to Scaffold, in creation order.

    Raises:
        ValueError: If two scaffolds end up with the same path.
    """
    scaffolding = OrderedDict()
    for ingredient in sorted_ingredients[:-1]:
        path = ingredient.path if ingredient != experiment else ''
        scaffolding[ingredient] = _create_scaffold(
            ingredient, scaffolding, path=path, generate_seed=False)

    # The experiment is always the root scaffold with the empty path.
    scaffolding[experiment] = _create_scaffold(
        experiment, scaffolding, path='', generate_seed=True)

    scaffolding_ret = OrderedDict([
        (sc.path, sc)
        for sc in scaffolding.values()
    ])
    if len(scaffolding_ret) != len(scaffolding):
        # Report the scaffold paths, since those are what collided.
        raise ValueError(
            'The pathes of the ingredients are not unique. '
            '{}'.format([sc.path for sc in scaffolding.values()])
        )

    return scaffolding_ret
269
270
271
def gather_ingredients_topological(ingredient):
    """Return ``ingredient`` and all its sub-ingredients, deepest first.

    Each ingredient is ranked by the greatest depth at which
    ``traverse_ingredients`` reports it. Ingredients of equal depth keep
    their first-seen order (``sorted`` is stable, also with ``reverse``).
    """
    max_depth = {}
    for sub_ing, depth in ingredient.traverse_ingredients():
        max_depth[sub_ing] = max(max_depth.get(sub_ing, 0), depth)
    return sorted(max_depth, key=lambda ing: max_depth[ing], reverse=True)
276
277
278
def get_config_modifications(scaffolding):
    """Combine the ``config_mods`` of all scaffolds into one ConfigSummary,
    each prefixed with its scaffold's path."""
    summary = ConfigSummary()
    for path, scaff in scaffolding.items():
        summary.update_add(scaff.config_mods, path=path)
    return summary
283
284
285
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    ``command_path`` is ``'<ingredient path>.<command name>'``; a bare
    command name refers to the root scaffold ``''``.

    Raises:
        KeyError: If the ingredient or the command does not exist.
    """
    ingredient_path, _, name = command_path.rpartition('.')
    if ingredient_path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    available = scaffolding[ingredient_path].commands
    if name in available:
        return available[name]
    if not ingredient_path:
        raise KeyError('Command "%s" not found' % name)
    raise KeyError('Command "%s" not found in ingredient "%s"' %
                   (name, ingredient_path))
298
299
300
def find_best_match(path, prefixes):
    """Find the Ingredient that shares the longest prefix with path.

    ``prefixes`` are ingredient paths already split on ``'.'``. The first
    one that forms a leading run of ``path``'s components wins, so callers
    pass them sorted from deepest to shallowest (as ``create_run`` does).

    Returns:
        ``(ingredient_path, remainder)``, or ``('', path)`` if nothing
        matches.
    """
    components = path.split('.')
    for candidate in prefixes:
        depth = len(candidate)
        if depth > len(components):
            continue
        if candidate == components[:depth]:
            return '.'.join(candidate), '.'.join(components[depth:])
    return '', path
307
308
309
def distribute_presets(prefixes, scaffolding, config_updates):
    """Route each flattened entry of ``config_updates`` into the ``presets``
    of the scaffold it belongs to.

    The owner is the first matching prefix (the deepest one when
    ``prefixes`` is sorted as in ``create_run``). The value is stored under
    the remaining relative key.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner_path, local_key = find_best_match(dotted_key, prefixes)
        target = scaffolding[owner_path].presets
        set_by_dotted_path(target, local_key, value)
314
315
316
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Route each flattened config update into the ``config_updates`` of the
    scaffold it belongs to.

    The owner is the first matching prefix (the deepest one when
    ``prefixes`` is sorted as in ``create_run``). The value is stored under
    the remaining relative key.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner_path, local_key = find_best_match(dotted_key, prefixes)
        owner = scaffolding[owner_path]
        set_by_dotted_path(owner.config_updates, local_key, value)
321
322
323
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Resolve a named-config reference to ``(scaffold, config name)``.

    An existing file path belongs to the root scaffold ``''`` and is
    returned unchanged as the name. Otherwise ``named_config`` is split into
    ``'<ingredient path>.<config name>'``; a bare name addresses the root.

    Raises:
        KeyError: If no scaffold exists for the ingredient path.
    """
    if os.path.exists(named_config):
        return scaffolding[''], named_config

    path, _, cfg_name = named_config.rpartition('.')
    if path not in scaffolding:
        raise KeyError('Ingredient for named config "{}" not found'
                       .format(named_config))
    return scaffolding[path], cfg_name
334
335
336
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False):
    """Assemble a Run for ``command_name`` of ``experiment``.

    The configuration is built in four phases:

    1. explicit ``config_updates`` are distributed to the scaffolds,
    2. ``named_configs`` are evaluated; their values become presets and
       config updates,
    3. the config scopes of each scaffold are evaluated (deepest ingredient
       first) and the scaffold's config hooks are applied to its config,
    4. seeds are resolved.

    Args:
        experiment: The experiment to run.
        command_name: Dotted name of the command (``'cmd'`` or
            ``'<ingredient path>.cmd'``).
        config_updates: Optional config updates, possibly with dotted keys.
        named_configs: Named-config names or paths of config files.
        force: If true, suspicious config changes are not checked.

    Returns:
        The initialized Run.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=len)

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # Named-config values also become (absolute) config updates.
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # Run the config hooks and apply their result to this scaffold's
        # config. All config updates were already distributed above, so the
        # merged dict must be applied here or it is lost. The scaffold's own
        # path-relative updates are merged in after the hook results.
        config_hook_updates = scaffold.run_config_hooks(
            config, scaffold.config_updates, command_name, run_logger)
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
406