Completed
Push — master ( eca2e7...82fa36 )
by Klaus
36s
created

iterate_flattened()   B

Complexity

Conditions 5

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
dl 0
loc 13
rs 8.5454
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import inspect
7
import logging
8
import os.path
9
import pkgutil
10
import re
11
import shlex
12
import sys
13
import threading
14
import traceback as tb
15
from functools import partial
16
17
import wrapt
18
19
20
# print_filtered_stacktrace hides traceback frames whose module globals
# contain ``__sacred__``; defining it here hides the frames of this module.
__sacred__ = True  # marks files that should be filtered from stack traces
21
22
# Public names of this module (the names exported by ``import *``).
__all__ = ["NO_LOGGER", "PYTHON_IDENTIFIER", "CircularDependencyError",
23
           "ObserverError", "SacredInterrupt", "TimeoutInterrupt",
24
           "create_basic_stream_logger", "recursive_update",
25
           "iterate_flattened", "iterate_flattened_separately",
26
           "set_by_dotted_path", "get_by_dotted_path", "iter_path_splits",
27
           "iter_prefixes", "join_paths", "is_prefix",
28
           "convert_to_nested_dict", "convert_camel_case_to_snake_case",
29
           "print_filtered_stacktrace", "is_subdir",
30
           "optional_kwargs_decorator", "get_inheritors",
31
           "apply_backspaces_and_linefeeds", "StringIO", "FileNotFoundError"]
32
33
# Python 2/3 compatibility names: ``basestring`` and ``int_types`` for
# isinstance checks, plus ``FileNotFoundError`` and ``StringIO``. They are
# defined once here so the same names work on either major version.
34
if sys.version_info[0] == 2:
35
    basestring = basestring
36
    int_types = (int, long)
37
38
    import errno
39
40
    # Python 2 has no FileNotFoundError builtin: emulate it as an IOError
    # whose errno is ENOENT.
    class FileNotFoundError(IOError):
41
        def __init__(self, msg):
42
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)
43
    from StringIO import StringIO
44
else:
45
    basestring = str
46
    int_types = (int,)
47
48
    # Rebind the builtin as a module-level name so that it can be imported
    # from here (it is listed in __all__).
49
    FileNotFoundError = FileNotFoundError
50
    from io import StringIO
51
52
53
# The logger named 'ignore', with ``disabled`` set so it handles no records.
NO_LOGGER = logging.getLogger('ignore')
54
NO_LOGGER.disabled = 1
55
56
# Sentinel value that iterate_flattened_separately yields (paired with the
# key) right before it descends into a nested dict.
PATHCHANGE = object()
57
58
# Matches ASCII Python identifier syntax (keywords are not excluded).
PYTHON_IDENTIFIER = re.compile("^[a-zA-Z_][_a-zA-Z0-9]*$")
59
60
61
class CircularDependencyError(Exception):
    """Signal that the current experiment's ingredients depend on each other
    in a cycle."""
63
64
65
class ObserverError(Exception):
    """Raised by an observer for a problem that should not fail the run."""
67
68
69
class SacredInterrupt(Exception):
    """Base class for all custom interrupts.

    The class attribute ``STATUS`` names the status an interrupted run gets;
    subclasses override it (e.g. ``TimeoutInterrupt`` uses ``TIMEOUT``).

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "INTERRUPTED"
76
77
78
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    Raise this from client code to indicate that the run exceeded its time
    limit and was interrupted for that reason. The interrupted run then
    receives the status ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
89
90
91
def create_basic_stream_logger():
    """Reset the root logger to a single INFO-level stream handler.

    Every handler already attached to the root logger is discarded and
    replaced by one ``logging.StreamHandler`` using the format
    ``LEVEL - logger name - message``.

    Returns
    -------
    logging.Logger
        The (root) logger that was configured.
    """
    root = logging.getLogger('')
    root.handlers = []
    root.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))
    root.addHandler(handler)
    return root
100
101
102
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mappings in ``u`` are merged into the corresponding entries of ``d``
    instead of replacing them; every other value of ``u`` overwrites the
    entry in ``d``. ``d`` is modified in place and also returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}

    If ``u[k]`` is a mapping but ``d[k]`` is missing or not a mapping, the
    old value is replaced by a fresh dict that ``u[k]`` is merged into.
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` on Python 3 and in ``collections`` on Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            target = d.get(k)
            if not isinstance(target, Mapping):
                # Nothing mergeable there (missing, None or a scalar), so
                # start from an empty dict instead of crashing on it.
                target = {}
            d[k] = recursive_update(target, v)
        else:
            d[k] = v
    return d
118
119
120
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    First iterate over manually sorted keys and then over all items that are
    non-dictionary values (sorted by keys), then over the rest
    (sorted by keys), providing full dotted paths for every leaf.

    Details:

    * Empty dicts count as leaves (non-dictionary values).
    * Before descending into a non-empty nested dict under ``key``, the pair
      ``(key, PATHCHANGE)`` is yielded to mark the start of a new sub-path.
    * ``manually_sorted_keys`` is applied again at every nesting level; the
      values under those keys are yielded as-is, without being flattened.
    """
    def _is_nested(value):
        # Check the type before truthiness: objects such as multi-element
        # numpy arrays raise from ``bool(value)``. This matches the order of
        # checks in iterate_flattened.
        return isinstance(value, dict) and bool(value)

    if manually_sorted_keys is None:
        manually_sorted_keys = []
    for key in manually_sorted_keys:
        if key in dictionary:
            yield key, dictionary[key]

    single_line_keys = [key for key in dictionary.keys() if
                        key not in manually_sorted_keys and
                        not _is_nested(dictionary[key])]
    for key in sorted(single_line_keys):
        yield key, dictionary[key]

    multi_line_keys = [key for key in dictionary.keys() if
                       key not in manually_sorted_keys and
                       _is_nested(dictionary[key])]
    for key in sorted(multi_line_keys):
        # Announce the new sub-path before yielding its leaves.
        yield key, PATHCHANGE
        for k, val in iterate_flattened_separately(dictionary[key],
                                                   manually_sorted_keys):
            yield join_paths(key, k), val
150
151
152
def iterate_flattened(d):
    """
    Recursively walk a nested dictionary and yield its leaves.

    Keys are visited in sorted order at every level. A non-empty nested dict
    is descended into, and its entries are reported under a dotted path
    built with ``join_paths``. Any other value, including an empty dict, is
    yielded as a leaf.

    Yields
    ------
    tuple
        ``(dotted_path, value)`` pairs.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if not (isinstance(value, dict) and value):
            yield key, value
            continue
        for sub_key, leaf in iterate_flattened(value):
            yield join_paths(key, sub_key), leaf
165
166
167
def set_by_dotted_path(d, path, value):
    """
    Set an entry in a nested dict using a dotted path.

    Intermediate dictionaries that do not exist yet are created on the way.
    ``d`` is modified in place; nothing is returned.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    segments = path.split('.')
    parents, leaf = segments[:-1], segments[-1]
    node = d
    for segment in parents:
        if segment not in node:
            node[segment] = {}
        node = node[segment]
    node[leaf] = value
191
192
193
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    An empty ``path`` returns ``d`` itself. As soon as a segment of the path
    is not contained in the current level, ``default`` is returned.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d
    node = d
    for segment in path.split('.'):
        if segment in node:
            node = node[segment]
        else:
            return default
    return node
210
211
212
def iter_path_splits(path):
    """
    Iterate over possible splits of a dotted path.

    Yields ``(head, tail)`` pairs, cutting the path before each of its
    segments. ``head`` may be empty; ``tail`` always holds at least the last
    segment.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    segments = path.split('.')
    for cut in range(len(segments)):
        head, tail = segments[:cut], segments[cut:]
        yield join_paths(*head), join_paths(*tail)
229
230
231
def iter_prefixes(path):
    """
    Iterate through all (non-empty) prefixes of a dotted path, shortest first.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    segments = path.split('.')
    for end in range(len(segments)):
        yield join_paths(*segments[:end + 1])
242
243
244
def join_paths(*parts):
    """
    Join different parts together to a valid dotted path.

    Each part is converted with ``str`` and stripped of leading and trailing
    dots. Falsy parts (``''``, ``None``, and also ``0``) are skipped. Parts
    that become empty after stripping (e.g. ``'.'``) are skipped too, so
    joining never introduces empty segments or a leading dot.

    Example:
    >>> join_paths('foo.', '', '.bar')
    'foo.bar'
    """
    stripped = (str(p).strip('.') for p in parts if p)
    return '.'.join(s for s in stripped if s)
247
248
249
def is_prefix(pre_path, path):
    """Return True if pre_path is a path-prefix of path.

    Surrounding dots are ignored on both arguments. An empty ``pre_path`` is
    a prefix of every path. Otherwise the prefix must end at a segment
    boundary and be strictly shorter than ``path``: a path is not a prefix
    of itself, and ``'a'`` is not a prefix of ``'ab'``.
    """
    prefix, full = pre_path.strip('.'), path.strip('.')
    if not prefix:
        return True
    return full.startswith(prefix + '.')
254
255
256
def convert_to_nested_dict(dotted_dict):
    """Convert a dict with dotted path keys to the corresponding nested dict.

    Example: ``{'a.b': 1, 'c': 2}`` becomes ``{'a': {'b': 1}, 'c': 2}``.
    Leaves are taken from ``iterate_flattened`` and written into a new dict
    with ``set_by_dotted_path``.
    """
    nested = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(nested, dotted_key, leaf)
    return nested
262
263
264
def _attach_exc_tracebacks(tb_exception, exc_value, exc_traceback):
    """Store raw traceback objects on a TracebackException and its chain.

    filtered_traceback_format reads ``tb_exception.exc_traceback`` so it can
    inspect each frame's globals. That attribute is not part of the
    documented ``traceback.TracebackException`` API, so it is not guaranteed
    to be set. This walks the ``__cause__``/``__context__`` chain of the
    TracebackException alongside the real exception chain and sets it.
    """
    tb_exception.exc_traceback = exc_traceback
    if exc_value is None:
        return
    for attr in ('__cause__', '__context__'):
        chained_tb_exception = getattr(tb_exception, attr, None)
        chained_exc = getattr(exc_value, attr, None)
        # TracebackException only records chained exceptions it has not
        # seen yet, so recursing on its structure always terminates.
        if chained_tb_exception is not None and chained_exc is not None:
            _attach_exc_tracebacks(chained_tb_exception, chained_exc,
                                   chained_exc.__traceback__)


def print_filtered_stacktrace():
    """Print the exception currently being handled to stderr.

    Must be called while an exception is being handled (it reads
    ``sys.exc_info()``).

    * If the innermost frame belongs to a module defining ``__sacred__``,
      the exception is formatted starting at that innermost frame, under a
      header saying it originated within Sacred.
    * Otherwise the traceback is printed with every frame from a
      ``__sacred__`` module omitted.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next
    if '__sacred__' in current_tb.tb_frame.f_globals:
        header = ["Exception originated from within Sacred.\n"
                  "Traceback (most recent calls):\n"]
        texts = tb.format_exception(exc_type, exc_value, current_tb)
        # texts[0] is the standard "Traceback ..." line; our header
        # replaces it.
        print(''.join(header + texts[1:]).strip(), file=sys.stderr)
    else:
        # traceback.TracebackException only exists from Python 3.5 on.
        if sys.version_info >= (3, 5):
            tb_exception =\
                tb.TracebackException(exc_type, exc_value, exc_traceback,
                                      limit=None)
            # Make the raw tracebacks available to filtered_traceback_format.
            _attach_exc_tracebacks(tb_exception, exc_value, exc_traceback)
            for line in filtered_traceback_format(tb_exception):
                print(line, file=sys.stderr, end="")
        else:
            print("Traceback (most recent calls WITHOUT Sacred internals):",
                  file=sys.stderr)
            current_tb = exc_traceback
            while current_tb is not None:
                if '__sacred__' not in current_tb.tb_frame.f_globals:
                    tb.print_tb(current_tb, 1)
                current_tb = current_tb.tb_next
            print("\n".join(tb.format_exception_only(exc_type,
                                                     exc_value)).strip(),
                  file=sys.stderr)
293
294
295
def filtered_traceback_format(tb_exception, chain=True):
    """Format ``tb_exception`` like ``traceback`` does, hiding Sacred frames.

    Generator yielding newline-terminated text chunks, in this order:

    1. the chained cause/context, when ``chain`` is true;
    2. the header;
    3. the frames of each traceback entry whose module globals do not
       define ``__sacred__``;
    4. the exception-only lines.

    Frame globals can only be inspected through the raw traceback objects,
    which are read from ``tb_exception.exc_traceback``. That attribute is
    not part of the documented ``TracebackException`` API. When it is
    missing, the complete, unfiltered ``tb_exception.stack`` is yielded under
    the standard header instead of raising ``AttributeError``.
    """
    if chain:
        if tb_exception.__cause__ is not None:
            for line in filtered_traceback_format(tb_exception.__cause__,
                                                  chain=chain):
                yield line
            yield tb._cause_message
        elif (tb_exception.__context__ is not None and
              not tb_exception.__suppress_context__):
            for line in filtered_traceback_format(tb_exception.__context__,
                                                  chain=chain):
                yield line
            yield tb._context_message
    missing = object()
    current_tb = getattr(tb_exception, 'exc_traceback', missing)
    if current_tb is missing:
        # Sacred frames cannot be recognised without the raw frames, so
        # fall back to the full stack rather than crashing.
        if tb_exception.stack:
            yield 'Traceback (most recent call last):\n'
            for line in tb_exception.stack.format():
                yield line
    else:
        yield 'Traceback (most recent calls WITHOUT Sacred internals):\n'
        while current_tb is not None:
            if '__sacred__' not in current_tb.tb_frame.f_globals:
                stack = tb.StackSummary.extract(tb.walk_tb(current_tb),
                                                limit=1,
                                                lookup_lines=True,
                                                capture_locals=False)
                for line in stack.format():
                    yield line
            current_tb = current_tb.tb_next
    for line in tb_exception.format_exception_only():
        yield line
321
322
323
def is_subdir(path, directory):
    """Return True if ``path`` lies inside ``directory`` or is the same path.

    Both arguments are resolved with ``os.path.realpath`` and made absolute
    before comparing. A separator is appended to each so that ``/foo/barbaz``
    does not count as being inside ``/foo/bar``.
    """
    def _canonical(p):
        return os.path.abspath(os.path.realpath(p)) + os.sep

    return _canonical(path).startswith(_canonical(directory))
328
329
330
# noinspection PyUnusedLocal
331
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """Let the decorator ``wrapped`` be used bare or with keyword options.

    ``@deco`` passes the decorated object positionally, so ``wrapped`` is
    called right away. ``@deco(key=value)`` passes only keywords, so a
    ``functools.partial`` of ``wrapped`` is returned to receive the target
    later.
    """
    if not args:
        # Used as @deco(...) with keyword options: bind them and wait.
        return partial(wrapped, **kwargs)
    # Used as a plain @deco: just call the decorator.
    return wrapped(*args, **kwargs)
338
339
340
def get_inheritors(cls):
    """Get a set of all classes that inherit from the given class.

    Both direct and indirect subclasses are included; ``cls`` itself is not.
    The class hierarchy is explored level by level via ``__subclasses__``.
    """
    found = set()
    frontier = [cls]
    while frontier:
        next_frontier = []
        for klass in frontier:
            for sub in klass.__subclasses__():
                if sub not in found:
                    found.add(sub)
                    next_frontier.append(sub)
        frontier = next_frontier
    return found
351
352
353
# Credit to Zarathustra and epost from stackoverflow
354
# Taken from http://stackoverflow.com/a/1176023/1388435
355
def convert_camel_case_to_snake_case(name):
    """Convert CamelCase to snake_case.

    Runs of capitals stay together as one word, e.g. ``'HTTPResponse'``
    becomes ``'http_response'``.
    """
    # 1) Insert '_' before a capitalised word (capital + lowercase letters).
    capitalised_word = re.compile('(.)([A-Z][a-z]+)')
    # 2) Insert '_' where a lowercase letter or digit meets a capital.
    lower_then_upper = re.compile('([a-z0-9])([A-Z])')
    partly_split = capitalised_word.sub(r'\1_\2', name)
    return lower_then_upper.sub(r'\1_\2', partly_split).lower()
359
360
361
def apply_backspaces_and_linefeeds(text):
    """
    Interpret backspaces and carriage returns in text like a terminal would.

    The text is processed line by line (split on ``'\\n'``):

    * ``'\\r'`` moves the cursor back to the start of the line, so later
      characters overwrite earlier ones.
    * ``'\\b'`` moves the cursor one position left (never past column 0)
      without erasing anything.
    * Other characters are written at the cursor position.

    If the final line ends with a carriage return, that ``'\\r'`` is kept so
    the result can be concatenated with the next output chunk.
    """
    lines = text.split('\n')
    last_line_idx = len(lines) - 1
    rendered = []
    for line_idx, line in enumerate(lines):
        buf = []
        pos = 0
        last_char_idx = len(line) - 1
        for char_idx, char in enumerate(line):
            trailing_cr = (char == '\r' and
                           line_idx == last_line_idx and
                           char_idx == last_char_idx)
            if char == '\r' and not trailing_cr:
                pos = 0
                continue
            if char == '\b':
                pos = max(0, pos - 1)
                continue
            if trailing_cr:
                # Keep the final '\r': append it after the line's content.
                pos = len(buf)
            if pos < len(buf):
                buf[pos] = char
            else:
                buf.append(char)
            pos += 1
        rendered.append(''.join(buf))
    return '\n'.join(rendered)
395
396
397
def module_exists(modname):
    """Checks if a module exists without actually importing it.

    The module itself is not imported. For a dotted name, however, the
    parent packages are imported in order to locate the submodule. A
    missing parent package makes the result ``False``.
    """
    try:
        from importlib.util import find_spec
    except ImportError:  # Python 2 has no importlib.util
        return pkgutil.find_loader(modname) is not None
    # pkgutil.find_loader is deprecated; it wrapped find_spec and returned
    # spec.loader, so check the same two conditions here directly.
    try:
        spec = find_spec(modname)
    except ImportError:
        # e.g. a parent package of a dotted name does not exist.
        return False
    except ValueError:
        # find_spec raises this when sys.modules[modname].__spec__ is None
        # (e.g. ``__main__``); a module that is loaded clearly exists.
        return modname in sys.modules
    return spec is not None and spec.loader is not None
400
401
402
def modules_exist(*modnames):
    """Return True if every named module exists (see ``module_exists``).

    Stops at the first missing module; returns True when called without
    arguments.
    """
    for name in modnames:
        if not module_exists(name):
            return False
    return True
404
405
406
def module_is_in_cache(modname):
    """Checks if a module was imported before (is in the import cache).

    Only ``sys.modules`` is consulted; nothing gets imported.
    """
    import_cache = sys.modules
    return modname in import_cache
409
410
411
def module_is_imported(modname, scope=None):
    """Checks if a module is imported within the current namespace.

    Parameters
    ----------
    modname : str
        Full name of the module (compared against ``module.__name__``).
    scope : dict, optional
        Namespace to search. Defaults to the globals of the calling frame.

    Returns
    -------
    bool
        True if ``scope`` holds a module object whose ``__name__`` equals
        ``modname``, whatever name it is bound to.
    """
    if not module_is_in_cache(modname):
        # Not loaded at all, so it cannot be bound in any namespace.
        return False

    if scope is None:
        # Frame record 1 of the stack is the caller of this function.
        scope = inspect.stack()[1][0].f_globals

    module_type = type(sys)
    return any(isinstance(obj, module_type) and obj.__name__ == modname
               for obj in scope.values())
425
426
427
def ensure_wellformed_argv(argv):
    """Return ``argv`` as a sequence of command-line strings.

    * ``None`` is replaced by ``sys.argv``.
    * A single string is split shell-style with ``shlex.split``.
    * A list or tuple is returned as-is once every element is confirmed to
      be a string.

    Raises
    ------
    ValueError
        If ``argv`` has any other type or contains non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, basestring):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}"
                         .format(type(argv)))
    problems = [a for a in argv if not isinstance(a, basestring)]
    if problems:
        raise ValueError("argv must be list of str but contained the "
                         "following elements: {}".format(problems))
    return argv
441
442
443
class IntervalTimer(threading.Thread):
    """Thread that calls ``func`` every ``interval`` seconds until stopped.

    Setting the stop event ends the loop; ``func`` is then called one final
    time before the thread exits.
    """

    @classmethod
    def create(cls, func, interval=10):
        """Build an unstarted timer thread together with its stop event.

        Returns
        -------
        tuple
            ``(stop_event, timer_thread)``; call ``stop_event.set()`` to
            stop the thread.
        """
        stop_event = threading.Event()
        return stop_event, cls(stop_event, func, interval)

    def __init__(self, event, func, interval=10.):
        threading.Thread.__init__(self)
        self.stopped = event
        self.func = func
        self.interval = interval

    def run(self):
        # Event.wait returns a false value on timeout (time to tick again)
        # and True once the event has been set.
        while True:
            if self.stopped.wait(self.interval):
                break
            self.func()
        # One final call after the stop was requested.
        self.func()
460