Completed
Push — master ( 416f46...e9b36c )
by Klaus
28s
created

initialize_logging()   B

Complexity

Conditions 6

Size

Total Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 21
rs 7.8867
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import os
6
from collections import OrderedDict, defaultdict
7
from copy import copy, deepcopy
8
9
from sacred.config import (ConfigDict, chain_evaluate_config_scopes, dogmatize,
10
                           load_config_file, undogmatize)
11
from sacred.config.config_summary import ConfigSummary
12
from sacred.host_info import get_host_info
13
from sacred.randomness import create_rnd, get_seed
14
from sacred.run import Run
15
from sacred.utils import (convert_to_nested_dict, create_basic_stream_logger,
16
                          get_by_dotted_path, is_prefix,
17
                          iterate_flattened, set_by_dotted_path,
18
                          recursive_update, iter_prefixes, join_paths)
19
20
21
class Scaffold(object):
    """Working state of one ingredient while a run is being configured.

    A scaffold keeps the config scopes, named configs, config hooks,
    commands and captured functions it was constructed with, together with
    everything derived from them during setup: the evaluated config, the
    fallback built from the subrunners, the seed / random state, the
    fixture, the logger and the summary of config modifications.
    """

    def __init__(self, config_scopes, subrunners, path, captured_functions,
                 commands, named_configs, config_hooks, generate_seed):
        """Store the collaborators and reset all derived state.

        `subrunners` is used as a mapping from path to the Scaffold of a
        sub-ingredient; `generate_seed` decides whether `set_up_seed` writes
        the seed into the config.
        """
        # Collaborators handed in by the caller.
        self.config_scopes = config_scopes
        self.named_configs = named_configs
        self.config_hooks = config_hooks
        self.commands = commands
        self.subrunners = subrunners
        self._captured_functions = captured_functions
        self.path = path
        self.generate_seed = generate_seed

        # State that is filled in step by step during configuration.
        self.config_updates = {}
        self.named_configs_to_use = []
        self.presets = {}
        self.config = {}
        self.summaries = []
        self.config_mods = None
        self.fallback = None
        self.fixture = None  # TODO: rename
        self.logger = None
        self.seed = None
        self.rnd = None

        # Every (prefixed) argument name of a captured function counts as a
        # config entry that is actually used somewhere.
        arg_paths = set()
        for cfunc in self._captured_functions:
            for arg_name in cfunc.signature.arguments:
                arg_paths.add(join_paths(cfunc.prefix, arg_name))
        arg_paths.add('__doc__')  # allow setting the config docstring
        self.captured_args = arg_paths

    def set_up_seed(self, rnd=None):
        """Fix the seed and random state of this scaffold (only once).

        The seed is read from the config; if the config has none, it is
        obtained from `get_seed(rnd)`. Subrunners whose path passes
        `is_prefix(self.path, ...)` are then seeded from this scaffold's
        random state.
        """
        if self.seed is not None:
            # Already seeded, e.g. by a parent scaffold.
            return

        seed = self.config.get('seed')
        if seed is None:
            seed = get_seed(rnd)
        self.seed = seed
        self.rnd = create_rnd(self.seed)

        if self.generate_seed:
            self.config['seed'] = self.seed

        # A seed that ended up in the config is reported as "modified"
        # rather than "added".
        if 'seed' in self.config and 'seed' in self.config_mods.added:
            self.config_mods.modified.add('seed')
            self.config_mods.added -= {'seed'}

        # Hierarchically set the seed of proper subrunners
        for sr_path, subrunner in reversed(list(self.subrunners.items())):
            if is_prefix(self.path, sr_path):
                subrunner.set_up_seed(self.rnd)

    def gather_fallbacks(self):
        """Collect the configs of all subrunners into `self.fallback`.

        Each subrunner config is stored under its path; if this scaffold has
        a non-empty path that prefixes the subrunner path, that prefix is
        stripped first.
        """
        collected = {}
        for sr_path, subrunner in self.subrunners.items():
            target = sr_path
            if self.path and is_prefix(self.path, sr_path):
                target = sr_path[len(self.path):].strip('.')
            set_by_dotted_path(collected, target, subrunner.config)

        # dogmatize to make the subrunner configurations read-only
        self.fallback = dogmatize(collected)
        self.fallback.revelation()

    def run_named_config(self, config_name):
        """Evaluate one named config and return the undogmatized result.

        `config_name` is loaded as a config file if such a path exists on
        disk; otherwise it is looked up in `self.named_configs`.
        """
        if not os.path.exists(config_name):
            named_cfg = self.named_configs[config_name]
        else:
            named_cfg = ConfigDict(load_config_file(config_name))

        evaluated = named_cfg(fixed=self.get_config_updates_recursive(),
                              preset=self.presets,
                              fallback=self.fallback)
        return undogmatize(evaluated)

    def set_up_config(self):
        """Evaluate the config scopes and refresh the modification summary."""
        evaluated = chain_evaluate_config_scopes(
            self.config_scopes,
            fixed=self.config_updates,
            preset=self.config,
            fallback=self.fallback)
        self.config, self.summaries = evaluated

        self.get_config_modifications()

    def run_config_hooks(self, config, config_updates, command_name, logger):
        """Merge the results of all config hooks with `config_updates`.

        Every hook receives its own deep copy of `config`. `config_updates`
        is merged last, i.e. after all hook results.
        """
        merged = {}
        for hook in self.config_hooks:
            hook_updates = hook(deepcopy(config), command_name, logger)
            if hook_updates:
                recursive_update(merged, hook_updates)
        recursive_update(merged, config_updates)
        return merged

    def get_config_modifications(self):
        """Rebuild `self.config_mods` from the updates and scope summaries."""
        added_keys = set(
            key for key, _ in iterate_flattened(self.config_updates))
        self.config_mods = ConfigSummary(added=added_keys)
        for summary in self.summaries:
            self.config_mods.update_from(summary)

    def get_config_updates_recursive(self):
        """Return a copy of the config updates including all subrunners'.

        The updates of each subrunner are stored under the subrunner's path
        as key.
        """
        updates = self.config_updates.copy()
        for sr_path, subrunner in self.subrunners.items():
            updates[sr_path] = subrunner.get_config_updates_recursive()
        return updates

    def get_fixture(self):
        """Return the fixture: a copy of the config plus subrunner configs.

        The result is built on the first call and cached in `self.fixture`.
        """
        if self.fixture is not None:
            return self.fixture

        self.fixture = copy(self.config)

        def graft_subrunner_configs(runner):
            for sr_path, subrunner in runner.subrunners.items():
                # I am not sure if it is necessary to trigger all
                subrunner.get_fixture()
                graft_subrunner_configs(subrunner)
                sub_config = copy(subrunner.config)
                target = sr_path
                if is_prefix(self.path, target):
                    target = sr_path[len(self.path):].strip('.')
                # Note: This might fail if we allow non-dict fixtures
                set_by_dotted_path(self.fixture, target, sub_config)

        graft_subrunner_configs(self)
        return self.fixture

    def finalize_initialization(self, run):
        """Equip all captured functions with logger, config, rnd and run.

        Unless `run.force` is set, suspicious config changes are reported
        afterwards.
        """
        # look at seed again, because it might have changed during the
        # configuration process
        if 'seed' in self.config:
            self.seed = self.config['seed']
        self.rnd = create_rnd(self.seed)

        for cfunc in self._captured_functions:
            cfunc.logger = self.logger.getChild(cfunc.__name__)
            cfunc.config = get_by_dotted_path(self.get_fixture(), cfunc.prefix,
                                              default={})
            # Each captured function gets its own random state, derived
            # from this scaffold's.
            cfunc.rnd = create_rnd(get_seed(self.rnd))
            cfunc.run = run

        if not run.force:
            self._warn_about_suspicious_changes()

    def _warn_about_suspicious_changes(self):
        """Report added entries, type changes and ignored fallbacks.

        Raises a KeyError for an added config entry for which neither the
        entry nor any of its prefixes is among the captured arguments.
        """
        for added_key in sorted(self.config_mods.added):
            if set(iter_prefixes(added_key)).isdisjoint(self.captured_args):
                raise KeyError('Added a new config entry "{}" that is not used'
                               ' anywhere'.format(added_key))
            self.logger.warning('Added new config entry: "%s"' % added_key)

        numeric = (int, float)
        for key, (type_old, type_new) in self.config_mods.typechanged.items():
            # Switching between int and float is not worth a warning.
            if not (type_old in numeric and type_new in numeric):
                self.logger.warning(
                    'Changed type of config entry "%s" from %s to %s' %
                    (key, type_old.__name__, type_new.__name__))

        for summary in self.summaries:
            for key in summary.ignored_fallbacks:
                self.logger.warning(
                    'Ignored attempt to set value of "%s", because it is an '
                    'ingredient.' % key
                )

    def __repr__(self):
        return "<Scaffold: '{}'>".format(self.path)
190
191
192
def get_configuration(scaffolding):
    """Assemble one config dict from the configs of all scaffolds.

    The scaffolds are visited in reverse order and scaffolds with an empty
    config are skipped. The config of the scaffold with the empty path is
    merged into the top level; every other config is stored under the
    scaffold's path.
    """
    combined = {}
    for sc_path, scaffold in reversed(list(scaffolding.items())):
        if not scaffold.config:
            continue
        if not sc_path:
            combined.update(scaffold.config)
        else:
            set_by_dotted_path(combined, sc_path, scaffold.config)
    return combined
202
203
204
def distribute_named_configs(scaffolding, named_configs):
    """Hand every named config to the scaffold it belongs to.

    A name that is an existing path on disk goes to the scaffold with the
    empty path. Any other name is split at its last dot into the path of
    the owning scaffold and the config name.

    Raises a KeyError if no scaffold exists for that path.
    """
    # NOTE(review): the Scaffold class in this file does not define a
    # `use_named_config` method -- confirm whether this function is still
    # in use.
    for ncfg in named_configs:
        if os.path.exists(ncfg):
            scaffolding[''].use_named_config(ncfg)
            continue

        path, _, cfg_name = ncfg.rpartition('.')
        if path not in scaffolding:
            raise KeyError('Ingredient for named config "{}" not found'
                           .format(ncfg))
        scaffolding[path].use_named_config(cfg_name)
214
215
216
def initialize_logging(experiment, scaffolding, log_level=None):
    """Attach loggers to all scaffolds and optionally set the log level.

    The root logger is `experiment.logger`, or a basic stream logger if the
    experiment has none. The scaffold with the empty path uses the root
    logger itself; every other scaffold gets a child logger named after its
    path. `log_level` is applied to the root logger, as an int if it can be
    converted and unchanged otherwise.

    Returns the root logger and its child named `experiment.path`.
    """
    root_logger = experiment.logger
    if root_logger is None:
        root_logger = create_basic_stream_logger()

    for sc_path, scaffold in scaffolding.items():
        scaffold.logger = (root_logger.getChild(sc_path) if sc_path
                           else root_logger)

    # set log level
    if log_level is not None:
        try:
            level = int(log_level)
        except ValueError:
            # not a number, pass it on unchanged (e.g. a level name)
            level = log_level
        root_logger.setLevel(level)

    run_logger = root_logger.getChild(experiment.path)
    return root_logger, run_logger
237
238
239
def _build_scaffold(ingredient, scaffolding, path, generate_seed):
    """Create the Scaffold of one ingredient and wire in its subrunners.

    The scaffolds of all sub-ingredients must already be in `scaffolding`.
    """
    subrunners = OrderedDict([(scaffolding[sub].path, scaffolding[sub])
                              for sub in ingredient.ingredients])
    return Scaffold(
        config_scopes=ingredient.configurations,
        subrunners=subrunners,
        path=path,
        captured_functions=ingredient.captured_functions,
        commands=ingredient.commands,
        named_configs=ingredient.named_configs,
        config_hooks=ingredient.config_hooks,
        generate_seed=generate_seed)


def create_scaffolding(experiment, sorted_ingredients):
    """Create one Scaffold per ingredient and return them keyed by path.

    All entries of `sorted_ingredients` except the last are processed in
    order, so each sub-ingredient must come before the ingredients using
    it. The scaffold of `experiment` is created last, gets the empty path
    and is the only one created with `generate_seed=True`.

    Returns an OrderedDict mapping scaffold path to Scaffold.
    Raises a ValueError if two scaffolds share the same path.
    """
    scaffolding = OrderedDict()
    for ingredient in sorted_ingredients[:-1]:
        scaffolding[ingredient] = _build_scaffold(
            ingredient, scaffolding,
            path=ingredient.path if ingredient != experiment else '',
            generate_seed=False)

    # The experiment always lives at the root, i.e. at the empty path.
    scaffolding[experiment] = _build_scaffold(
        experiment, scaffolding, path='', generate_seed=True)

    scaffolding_ret = OrderedDict([
        (sc.path, sc)
        for sc in scaffolding.values()
    ])
    if len(scaffolding_ret) != len(scaffolding):
        # Report the scaffold paths, since those are the ones that collided.
        raise ValueError(
            'The pathes of the ingredients are not unique. '
            '{}'.format([sc.path for sc in scaffolding.values()])
        )

    return scaffolding_ret
275
276
277
def gather_ingredients_topological(ingredient):
    """Return all traversed ingredients ordered from deepest to shallowest.

    Every ingredient yielded by `ingredient.traverse_ingredients()` is
    ranked by the largest depth it was yielded with (never below 0).
    Ingredients of equal depth keep the order of their first appearance.
    """
    max_depth = {}
    for sub_ing, depth in ingredient.traverse_ingredients():
        max_depth[sub_ing] = max(max_depth.get(sub_ing, 0), depth)
    # a reversed sort is still stable, so ties keep their original order
    return sorted(max_depth, key=max_depth.get, reverse=True)
282
283
284
def get_config_modifications(scaffolding):
    """Combine the config modifications of all scaffolds into one summary.

    The `config_mods` of every scaffold are added to a fresh ConfigSummary
    under the scaffold's path.
    """
    combined = ConfigSummary()
    for sc_path in scaffolding:
        combined.update_add(scaffolding[sc_path].config_mods, path=sc_path)
    return combined
289
290
291
def get_command(scaffolding, command_path):
    """Look up a command by its dotted path.

    Everything before the last dot of `command_path` is the path of the
    owning scaffold, the remainder is the command name.

    Raises a KeyError if either the scaffold or the command is unknown.
    """
    path, _, command_name = command_path.rpartition('.')
    if path not in scaffolding:
        raise KeyError('Ingredient for command "%s" not found.' % command_path)

    commands = scaffolding[path].commands
    if command_name in commands:
        return commands[command_name]

    # Only mention the ingredient if the command does not belong to the root.
    if not path:
        raise KeyError('Command "%s" not found' % command_name)
    raise KeyError('Command "%s" not found in ingredient "%s"' %
                   (command_name, path))
304
305
306
def find_best_match(path, prefixes):
    """Split path into the first matching prefix and the remaining suffix.

    `prefixes` holds prefixes that are already split at the dots. The first
    one that matches the beginning of `path` wins, so the longest prefix is
    only found if the caller passes them sorted from longest to shortest.
    Returns ('', path) if no prefix matches.
    """
    parts = path.split('.')
    for prefix in prefixes:
        depth = len(prefix)
        if depth > len(parts):
            continue
        if prefix == parts[:depth]:
            return '.'.join(prefix), '.'.join(parts[depth:])
    return '', path
313
314
315
def distribute_presets(prefixes, scaffolding, config_updates):
    """Store every flattened entry as a preset of the matching scaffold.

    The scaffold is the one whose path `find_best_match` returns for the
    entry; the value is set in its `presets` under the remaining suffix.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner, remainder = find_best_match(dotted_key, prefixes)
        set_by_dotted_path(scaffolding[owner].presets, remainder, value)
320
321
322
def distribute_config_updates(prefixes, scaffolding, config_updates):
    """Store every flattened entry as an update of the matching scaffold.

    The scaffold is the one whose path `find_best_match` returns for the
    entry; the value is set in its `config_updates` under the remaining
    suffix.
    """
    for dotted_key, value in iterate_flattened(config_updates):
        owner, remainder = find_best_match(dotted_key, prefixes)
        set_by_dotted_path(scaffolding[owner].config_updates, remainder,
                           value)
327
328
329
def get_scaffolding_and_config_name(named_config, scaffolding):
    """Return the scaffold responsible for a named config and its name.

    A `named_config` that is an existing path on disk belongs to the
    scaffold with the empty path and is returned unchanged. Otherwise it is
    split at its last dot into scaffold path and config name.

    Raises a KeyError if no scaffold exists for that path.
    """
    if os.path.exists(named_config):
        return scaffolding[''], named_config

    path, _, cfg_name = named_config.rpartition('.')
    if path not in scaffolding:
        raise KeyError('Ingredient for named config "{}" not found'
                       .format(named_config))
    return scaffolding[path], cfg_name
340
341
342
def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    """Configure all ingredients of an experiment and build the Run.

    The configuration is assembled in four phases: config updates, named
    configs, the regular config scopes (followed by the config hooks) and
    finally seeding. Afterwards the Run object is created for the command
    `command_name` and every scaffold finalizes its captured functions
    with it.

    Returns the created Run.
    """
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted(
        (sc_path.split('.') for sc_path in scaffolding if sc_path != ''),
        key=len, reverse=True)

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = convert_to_nested_dict(config_updates or {})
    root_logger, run_logger = initialize_logging(experiment, scaffolding,
                                                 log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for named_config in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(named_config,
                                                          scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        # the results of a named config also count as config updates,
        # relative to the scaffold the named config belongs to
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            full_key = join_paths(scaff.path, ncfg_key)
            set_by_dotted_path(config_updates, full_key, value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        # NOTE(review): the merged result is only passed on to the hooks of
        # the following scaffolds; nothing visible here writes it back into
        # a scaffold -- confirm that this is intended.
        config_updates = scaffold.run_config_hooks(config, config_updates,
                                                   command_name, run_logger)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)

    pre_runs = []
    for ingredient in sorted_ingredients:
        pre_runs.extend(ingredient.pre_run_hooks)
    post_runs = []
    for ingredient in sorted_ingredients:
        post_runs.extend(ingredient.post_run_hooks)

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    # a command may carry an `unobserved` attribute which is handed on
    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run
413