Completed
Push — master ( 6c3661...416f46 )
by Klaus
34s
created

_is_sacred_frame()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 2
rs 10
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import inspect
7
import logging
8
import os.path
9
import pkgutil
10
import re
11
import shlex
12
import sys
13
import threading
14
import traceback as tb
15
from functools import partial
16
17
import wrapt
18
19
20
# Public names exported by a star-import of this module.
__all__ = [
    "NO_LOGGER",
    "PYTHON_IDENTIFIER",
    "CircularDependencyError",
    "ObserverError",
    "SacredInterrupt",
    "TimeoutInterrupt",
    "create_basic_stream_logger",
    "recursive_update",
    "iterate_flattened",
    "iterate_flattened_separately",
    "set_by_dotted_path",
    "get_by_dotted_path",
    "iter_path_splits",
    "iter_prefixes",
    "join_paths",
    "is_prefix",
    "convert_to_nested_dict",
    "convert_camel_case_to_snake_case",
    "print_filtered_stacktrace",
    "is_subdir",
    "optional_kwargs_decorator",
    "get_inheritors",
    "apply_backspaces_and_linefeeds",
    "StringIO",
    "FileNotFoundError",
]
30
31
# Python 2/3 compatibility shims. On both major versions this module provides:
#   basestring        - base type for isinstance checks on text
#   int_types         - tuple of integer types for isinstance checks
#   FileNotFoundError - the py3 builtin (or an IOError-based stand-in on py2)
#   StringIO          - an in-memory text stream class
if sys.version_info[0] != 2:
    basestring = str
    int_types = (int,)

    # Rebind the builtin at module level so it can be imported from here.
    FileNotFoundError = FileNotFoundError
    from io import StringIO
else:
    basestring = basestring
    int_types = (int, long)

    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in for the builtin, carrying errno ENOENT."""

        def __init__(self, msg):
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)

    from StringIO import StringIO
49
50
51
# A logger named 'ignore' whose ``disabled`` flag is set, so the logging
# machinery drops every record sent to it. Useful wherever a logger object is
# required but no output is wanted.
NO_LOGGER = logging.getLogger('ignore')
NO_LOGGER.disabled = True

# Unique sentinel yielded by ``iterate_flattened_separately`` right before the
# entries of a nested dictionary.
PATHCHANGE = object()

# Matches ASCII-only, Python-style identifiers. ``\Z`` is used instead of ``$``
# because ``$`` also matches just before a trailing newline, which would let
# e.g. "foo\n" pass as an identifier.
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")
57
58
59
class CircularDependencyError(Exception):
    """Error for a dependency cycle among the experiment's ingredients."""
61
62
63
class ObserverError(Exception):
    """Raised by an observer for problems that must not fail the run."""
65
66
67
class SacredInterrupt(Exception):
    """Common base class of all custom interrupt exceptions.

    Subclasses override ``STATUS`` with the status string that an
    interrupted run should be given.
    See :ref:`custom_interrupts` for details.
    """

    # Status string assigned to a run that ends with this interrupt.
    STATUS = "INTERRUPTED"
74
75
76
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    This exception can be used in client code to indicate that the run
    exceeded its time limit and has been interrupted because of that.
    The status of the interrupted run will then be set to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    # Status string assigned to a run that ends with this interrupt.
    STATUS = "TIMEOUT"
87
88
89
def create_basic_stream_logger():
    """Reset the root logger to a single INFO-level stream handler.

    Every handler previously attached to the root logger is discarded. It is
    replaced by one ``logging.StreamHandler`` that formats records as
    ``<level> - <logger name> - <message>``.

    Returns:
        The reconfigured root logger.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))

    root = logging.getLogger('')
    root.setLevel(logging.INFO)
    root.handlers = []
    root.addHandler(handler)
    return root
98
99
100
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mapping values in ``u`` are merged into the corresponding entries of
    ``d``. A missing entry, or one that is not itself a mapping, is replaced
    by a new dict first. Every other value in ``u`` simply overwrites the
    entry in ``d``. ``d`` is modified in place and also returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    # ``collections.Mapping`` was removed in Python 3.10. The ABC lives in
    # ``collections.abc`` (Python 3.3+); the fallback keeps Python 2 working.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            existing = d.get(k)
            if not isinstance(existing, Mapping):
                # Nothing mergeable there yet (missing key or a plain value
                # such as an int): start from a fresh dict instead of
                # failing on item assignment.
                existing = {}
            d[k] = recursive_update(existing, v)
        else:
            d[k] = v
    return d
116
117
118
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively walk a nested dictionary, yielding ``(path, value)`` pairs.

    Order of the output:

    1. every key of ``manually_sorted_keys`` present in ``dictionary``, in
       the given order, with its value as-is (even if it is a dict);
    2. the remaining keys whose value is not a non-empty dict, sorted;
    3. the remaining keys whose value is a non-empty dict, sorted. Each such
       key is first yielded with the ``PATHCHANGE`` sentinel. Its own
       entries follow, recursively and with the same
       ``manually_sorted_keys``, under full dotted paths.
    """
    priority = [] if manually_sorted_keys is None else manually_sorted_keys

    def _is_nested(key):
        # Only non-empty dicts are descended into.
        value = dictionary[key]
        return bool(value) and isinstance(value, dict)

    for key in priority:
        if key in dictionary:
            yield key, dictionary[key]

    flat_keys = [k for k in dictionary.keys()
                 if k not in priority and not _is_nested(k)]
    for key in sorted(flat_keys):
        yield key, dictionary[key]

    nested_keys = [k for k in dictionary.keys()
                   if k not in priority and _is_nested(k)]
    for key in sorted(nested_keys):
        yield key, PATHCHANGE
        children = iterate_flattened_separately(dictionary[key], priority)
        for sub_path, value in children:
            yield join_paths(key, sub_path), value
148
149
150
def iterate_flattened(d):
    """
    Walk a nested dictionary depth-first in sorted key order.

    Yields ``(dotted_path, value)`` for every leaf. A leaf is any value
    that is not a non-empty dict, so empty dicts are yielded as values.
    """
    for key in sorted(d.keys()):
        child = d[key]
        if not (isinstance(child, dict) and child):
            yield key, child
            continue
        for sub_path, leaf in iterate_flattened(child):
            yield join_paths(key, sub_path), leaf
163
164
165
def set_by_dotted_path(d, path, value):
    """
    Assign ``value`` inside nested dicts, addressed by a dotted ``path``.

    Intermediate levels that do not exist yet are created as empty dicts.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    keys = path.split('.')
    parent_keys, leaf_key = keys[:-1], keys[-1]
    node = d
    for key in parent_keys:
        if key not in node:
            node[key] = dict()
        node = node[key]
    node[leaf_key] = value
189
190
191
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    An empty ``path`` returns ``d`` itself. ``default`` is returned if a
    component of the path is missing. It is also returned if the path runs
    into a value that cannot be looked up by key, such as a number or a
    string.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    >>> get_by_dotted_path({'foo': 3}, 'foo.a', default=0)
    0
    """
    if not path:
        return d
    current_option = d
    for p in path.split('.'):
        try:
            if p not in current_option:
                return default
            current_option = current_option[p]
        except TypeError:
            # e.g. ``'a' in 3`` or ``'abc'['a']``: the path descends below a
            # leaf value, so the requested entry does not exist.
            return default
    return current_option
208
209
210
def iter_path_splits(path):
    """
    Iterate over possible splits of a dotted path.

    Each split is a ``(head, tail)`` pair of dotted paths. The first pair
    always has an empty head, and the tail of each pair holds the rest of
    the path.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    split_path = path.split('.')
    for i in range(len(split_path)):
        p1 = join_paths(*split_path[:i])
        p2 = join_paths(*split_path[i:])
        yield p1, p2
227
228
229
def iter_prefixes(path):
    """
    Yield the prefixes of a dotted path, shortest first, ending with the
    full path.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    parts = path.split('.')
    for end in range(len(parts)):
        yield join_paths(*parts[:end + 1])
240
241
242
def join_paths(*parts):
    """
    Join different parts together to a valid dotted path.

    Falsy parts (``''``, ``None``, ``0``, ...) are skipped. Every other part
    is converted with ``str`` and stripped of leading and trailing dots.
    Parts that end up empty after stripping (e.g. ``'.'``) are skipped too,
    so they cannot introduce an empty path component like ``'a..b'``.

    >>> join_paths('foo', '.bar.', '', 'baz')
    'foo.bar.baz'
    """
    stripped = (str(p).strip('.') for p in parts if p)
    return '.'.join(s for s in stripped if s)
245
246
247
def is_prefix(pre_path, path):
    """Return True if pre_path is a path-prefix of path.

    Leading and trailing dots are ignored on both arguments. An empty prefix
    matches every path. Otherwise ``path`` must continue with a ``.`` right
    after the prefix. A path is therefore not a prefix of itself, and
    ``'a'`` is not a prefix of ``'ab'``.
    """
    prefix = pre_path.strip('.')
    full = path.strip('.')
    if not prefix:
        return True
    return full.startswith(prefix + '.')
252
253
254
def convert_to_nested_dict(dotted_dict):
    """Build a nested dict from a dict whose keys may be dotted paths.

    Nested dictionaries in the input are flattened first. As a result, both
    ``{'a.b': 1}`` and ``{'a': {'b': 1}}`` become ``{'a': {'b': 1}}``.
    """
    result = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf)
    return result
260
261
262
def _is_sacred_frame(frame):
    """Return True if ``frame`` executes code from the ``sacred`` package.

    Only the top-level package of the frame's module ``__name__`` is
    compared. ``sacred`` and ``sacred.x.y`` match; ``sacredx`` does not.
    """
    top_level_package, _, _ = frame.f_globals["__name__"].partition('.')
    return top_level_package == 'sacred'
264
265
266
def print_filtered_stacktrace():
    """Print the exception currently being handled to ``sys.stderr``.

    Meant to be called while an exception is being handled (e.g. from an
    ``except`` block). Without an active exception nothing is printed.

    * If the innermost traceback frame belongs to Sacred, the traceback
      from that frame on is printed under a header saying that the
      exception originated from within Sacred.
    * Otherwise the traceback is printed with all frames from Sacred's own
      modules filtered out. On Python >= 3.5 (where
      ``traceback.TracebackException`` exists) chained causes and contexts
      are included.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    if exc_traceback is None:
        # No exception is being handled; the traversal below would fail on
        # ``None.tb_next``.
        return
    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next
    if _is_sacred_frame(current_tb.tb_frame):
        header = ["Exception originated from within Sacred.\n"
                  "Traceback (most recent calls):\n"]
        texts = tb.format_exception(exc_type, exc_value, current_tb)
        # texts[0] is the generic "Traceback ..." line; use our header.
        print(''.join(header + texts[1:]).strip(), file=sys.stderr)
    else:
        # traceback.TracebackException was added in Python 3.5; older
        # versions take the manual path below.
        if sys.version_info >= (3, 5):
            tb_exception = \
                tb.TracebackException(exc_type, exc_value, exc_traceback,
                                      limit=None)
            for line in filtered_traceback_format(tb_exception):
                print(line, file=sys.stderr, end="")
        else:
            print("Traceback (most recent calls WITHOUT Sacred internals):",
                  file=sys.stderr)
            current_tb = exc_traceback
            while current_tb is not None:
                if not _is_sacred_frame(current_tb.tb_frame):
                    tb.print_tb(current_tb, 1)
                current_tb = current_tb.tb_next
            # format_exception_only returns lines that already end with a
            # newline; joining with "" avoids blank lines between them
            # (e.g. for multi-line SyntaxError messages).
            print(''.join(tb.format_exception_only(exc_type,
                                                   exc_value)).strip(),
                  file=sys.stderr)
295
296
297
def filtered_traceback_format(tb_exception, chain=True):
    """Yield the lines of a traceback with Sacred's own frames removed.

    If ``chain`` is true, the chained exception is formatted first: the
    explicit cause (``__cause__``), or otherwise the context
    (``__context__``) unless it is suppressed. That output is followed by
    the matching standard separator message. Then the traceback of
    ``tb_exception`` itself is emitted one entry at a time, skipping frames
    from the ``sacred`` package. The exception-only lines come last.

    NOTE(review): this reads ``tb_exception.exc_traceback``, which is not a
    documented ``traceback.TracebackException`` attribute. Confirm it exists
    on every supported Python version.
    """
    if chain:
        parent = None
        if tb_exception.__cause__ is not None:
            parent, explicit = tb_exception.__cause__, True
        elif (tb_exception.__context__ is not None and
              not tb_exception.__suppress_context__):
            parent, explicit = tb_exception.__context__, False
        if parent is not None:
            for line in filtered_traceback_format(parent, chain=chain):
                yield line
            yield tb._cause_message if explicit else tb._context_message
    yield 'Traceback (most recent calls WITHOUT Sacred internals):\n'
    tb_entry = tb_exception.exc_traceback
    while tb_entry is not None:
        if not _is_sacred_frame(tb_entry.tb_frame):
            # Summarize just this one traceback entry (limit=1).
            single = tb.StackSummary.extract(tb.walk_tb(tb_entry),
                                             limit=1,
                                             lookup_lines=True,
                                             capture_locals=False)
            for line in single.format():
                yield line
        tb_entry = tb_entry.tb_next
    for line in tb_exception.format_exception_only():
        yield line
323
324
325
def is_subdir(path, directory):
    """Return True if ``path`` is ``directory`` itself or lies inside it.

    Both arguments are resolved with ``realpath``/``abspath`` and given a
    trailing separator before comparing. This way ``/ab`` is not considered
    to be inside ``/a``.
    """
    def _normalized(p):
        return os.path.abspath(os.path.realpath(p)) + os.sep

    return _normalized(path).startswith(_normalized(directory))
330
331
332
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """Let a decorator be applied either bare or with keyword arguments.

    ``wrapped`` is itself a decorator.

    * Bare (``@deco``): the decorated object arrives as a positional
      argument, so ``wrapped`` is called right away.
    * Keyword-only (``@deco(key=value)``): there are no positional
      arguments, so a ``functools.partial`` with the keyword arguments is
      returned, to be applied to the decorated object next.
    """
    if not args:
        # Options only: bind them now, decorate on the next call.
        return partial(wrapped, **kwargs)
    # Plain use: behave exactly like the wrapped decorator.
    return wrapped(*args, **kwargs)
340
341
342
def get_inheritors(cls):
    """Get a set of all classes that inherit from the given class.

    Direct and indirect subclasses are included; ``cls`` itself is not.
    Classes reachable through several paths (diamonds) appear once.
    """
    found = set()
    pending = [cls]
    while pending:
        for child in pending.pop().__subclasses__():
            if child in found:
                continue
            found.add(child)
            pending.append(child)
    return found
353
354
355
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """Convert CamelCase to snake_case.

    The conversion runs in three steps:

    1. Insert an underscore before every capitalised word that follows
       some character.
    2. Insert one between a lowercase letter or digit and a following
       capital.
    3. Lowercase the result.
    """
    with_word_breaks = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    with_all_breaks = re.sub('([a-z0-9])([A-Z])', r'\1_\2', with_word_breaks)
    return with_all_breaks.lower()
361
362
363
def apply_backspaces_and_linefeeds(text):
    """
    Interpret backspaces and carriage returns in text like a terminal would.

    The text is processed line by line (split on newlines):

    * a carriage return moves the cursor back to the start of the line, so
      following characters overwrite what is already there;
    * a backspace moves the cursor one position to the left (never past the
      start of the line) without erasing anything;
    * any other character is written at the cursor position (overwriting or
      appending) and advances the cursor.

    If the final line ends with a carriage return, that carriage return is
    kept (appended at the end of the line). The result can then be
    concatenated with the next output chunk.
    """
    orig_lines = text.split('\n')
    last_line_idx = len(orig_lines) - 1
    new_lines = []
    for line_idx, line in enumerate(orig_lines):
        chars, cursor = [], 0
        last_char_idx = len(line) - 1
        for char_idx, char in enumerate(line):
            # The very last character of the whole text, if it is '\r', is
            # preserved rather than interpreted.
            is_trailing_cr = (char == '\r' and
                              char_idx == last_char_idx and
                              line_idx == last_line_idx)
            if char == '\r' and not is_trailing_cr:
                cursor = 0
            elif char == '\b':
                cursor = max(0, cursor - 1)
            else:
                if is_trailing_cr:
                    # Keep the trailing '\r' at the end of the line.
                    cursor = len(chars)
                if cursor == len(chars):
                    chars.append(char)
                else:
                    chars[cursor] = char
                cursor += 1
        new_lines.append(''.join(chars))
    return '\n'.join(new_lines)
397
398
399
def module_exists(modname):
    """Checks if a module exists without actually importing it.

    Only the parent packages of a dotted ``modname`` get imported, which is
    needed to locate submodules. Returns False if the module, or one of its
    parent packages, cannot be found.
    """
    try:
        from importlib.util import find_spec
    except ImportError:  # Python 2 / early 3.x: no importlib.util.find_spec
        return pkgutil.find_loader(modname) is not None
    # pkgutil.find_loader is deprecated since Python 3.12 and removed in 3.14;
    # importlib.util.find_spec is its documented replacement.
    try:
        return find_spec(modname) is not None
    except ImportError:
        # e.g. a parent package of a dotted name does not exist.
        return False
    except ValueError:
        # find_spec raises this for a module already in sys.modules without
        # a usable __spec__: such a module is loaded, hence exists.
        return modname in sys.modules
402
403
404
def modules_exist(*modnames):
    """Return True if every named module can be found (True for no names).

    Stops at the first module that cannot be found.
    """
    for name in modnames:
        if not module_exists(name):
            return False
    return True
406
407
408
def module_is_in_cache(modname):
    """Return True if ``modname`` is a key of ``sys.modules``.

    A module is added there once it has been imported anywhere in the
    process. This says nothing about whether the current namespace holds a
    reference to it.
    """
    loaded = sys.modules
    return modname in loaded
411
412
413
def module_is_imported(modname, scope=None):
    """Checks if a module is imported within the current namespace.

    Args:
        modname: Full name of the module (its ``__name__``).
        scope: Namespace dict to search. Defaults to the ``globals()`` of
            the caller.

    Returns:
        True if ``modname`` is in ``sys.modules`` and some value in
        ``scope`` is a module object with that ``__name__``. Matching is by
        value, not by key, so aliased imports such as
        ``import numpy as np`` are found too.
    """
    # return early if modname is not even cached
    if not module_is_in_cache(modname):
        return False

    if scope is None:  # use globals() of the caller by default
        # Only the caller's frame is needed. inspect.stack() would build
        # frame records (with source context) for the entire call stack.
        # The frame is not stored in a local to avoid a reference cycle.
        scope = inspect.currentframe().f_back.f_globals

    module_type = type(sys)
    return any(isinstance(m, module_type) and m.__name__ == modname
               for m in scope.values())
427
428
429
def ensure_wellformed_argv(argv):
    """Normalize ``argv`` into a sequence of string arguments.

    Args:
        argv: one of
            * ``None``: use ``sys.argv``;
            * a single command-line string: split shell-style with
              ``shlex.split``;
            * a list or tuple of strings: returned unchanged.

    Returns:
        The resulting argument list (or the given list/tuple).

    Raises:
        ValueError: if ``argv`` has any other type, or is a list/tuple
            containing non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, basestring):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}"
                         .format(type(argv)))
    non_strings = [a for a in argv if not isinstance(a, basestring)]
    if non_strings:
        raise ValueError("argv must be list of str but contained the "
                         "following elements: {}".format(non_strings))
    return argv
443
444
445
class IntervalTimer(threading.Thread):
    """Thread that calls ``func`` every ``interval`` seconds until stopped.

    The thread waits on a ``threading.Event``. Once that event is set, it
    calls ``func`` one final time and exits.
    """

    @classmethod
    def create(cls, func, interval=10):
        """Build a timer together with the event that stops it.

        Returns:
            ``(stop_event, timer_thread)``; the thread is not started yet.
        """
        stop_event = threading.Event()
        return stop_event, cls(stop_event, func, interval)

    def __init__(self, event, func, interval=10.):
        threading.Thread.__init__(self)
        self.stopped = event  # setting this event ends the loop in run()
        self.func = func
        self.interval = interval

    def run(self):
        """Call ``func`` on every timeout, then once more after the stop."""
        while True:
            if self.stopped.wait(self.interval):
                break
            self.func()
        # Final call, made after the stop event has been set.
        self.func()
462