Completed
Push — master ( 5d8e11...d66198 )
by Klaus
03:33
created

rel_path()   A

Complexity

Conditions 3

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
dl 0
loc 6
rs 9.4285
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import inspect
7
import logging
8
import os.path
9
import pkgutil
10
import re
11
import shlex
12
import sys
13
import threading
14
import traceback as tb
15
from functools import partial
16
17
import wrapt
18
19
20
# Names exported by ``from <this module> import *``.
__all__ = [
    "NO_LOGGER",
    "PYTHON_IDENTIFIER",
    "CircularDependencyError",
    "ObserverError",
    "SacredInterrupt",
    "TimeoutInterrupt",
    "create_basic_stream_logger",
    "recursive_update",
    "iterate_flattened",
    "iterate_flattened_separately",
    "set_by_dotted_path",
    "get_by_dotted_path",
    "iter_path_splits",
    "iter_prefixes",
    "join_paths",
    "is_prefix",
    "convert_to_nested_dict",
    "convert_camel_case_to_snake_case",
    "print_filtered_stacktrace",
    "is_subdir",
    "optional_kwargs_decorator",
    "get_inheritors",
    "apply_backspaces_and_linefeeds",
    "StringIO",
    "FileNotFoundError",
    "rel_path",
]
31
32
# Python 2/3 compatibility shims: ``basestring`` and ``int_types`` for type
# checks, plus ``FileNotFoundError`` and ``StringIO`` importable from this
# module on either major version.
if sys.version_info[0] != 2:
    basestring = str
    int_types = (int,)

    # Re-bind the builtin as a module global so it can be imported from here.
    FileNotFoundError = FileNotFoundError
    from io import StringIO
else:
    basestring = basestring
    int_types = (int, long)

    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in for the Python 3 builtin: an IOError(ENOENT)."""

        def __init__(self, msg):
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)
    from StringIO import StringIO
50
51
52
# Logger named 'ignore' with ``disabled`` set, so it drops every record; use
# it wherever a logger object is required but no output is wanted.
NO_LOGGER = logging.getLogger('ignore')
NO_LOGGER.disabled = True

# Unique sentinel that iterate_flattened_separately yields as the value of a
# key whose (non-empty dict) contents follow as deeper dotted paths.
PATHCHANGE = object()

# Matches an identifier-shaped ASCII name (keywords are not excluded).
# ``\Z`` instead of ``$``: ``$`` also matches right before a trailing newline,
# which would accept e.g. "foo\n".
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")
58
59
60
class CircularDependencyError(Exception):
    """
    Signals that the ingredients of the current experiment form a cycle.
    """
62
63
64
class ObserverError(Exception):
    """
    Raised by an observer for a problem that should not fail the run.
    """
66
67
68
class SacredInterrupt(Exception):
    """Common base class of all custom interrupts.

    Subclasses override ``STATUS`` with the status the interrupted run
    should receive.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "INTERRUPTED"
75
76
77
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    Raise this from client code to report that the run went over its time
    limit and was interrupted because of it; the run's status is then set
    to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
88
89
90
def create_basic_stream_logger():
    """
    Reset the root logger to one stream handler at INFO level and return it.

    Any handlers previously attached to the root logger are dropped. The new
    ``logging.StreamHandler`` (stderr by default) formats records as
    ``LEVEL - logger.name - message``.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))

    root_logger = logging.getLogger('')
    root_logger.setLevel(logging.INFO)
    root_logger.handlers = []
    root_logger.addHandler(handler)
    return root_logger
99
100
101
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Values in ``u`` that are mappings are merged into the corresponding entry
    of ``d`` (missing entries start as ``{}``); every other value simply
    overwrites. ``d`` is modified in place and also returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    # The ABC aliases in ``collections`` were removed in Python 3.10; on
    # Python 2 ``collections.abc`` does not exist yet.
    try:
        from collections.abc import Mapping
    except ImportError:
        Mapping = collections.Mapping
    for k, v in u.items():
        if isinstance(v, Mapping):
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = u[k]
    return d
117
118
119
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    ``(dotted_path, value)`` pairs are yielded in this order:

    1. keys from ``manually_sorted_keys`` that are present, in that order
       (their values are yielded as-is, even when they are dicts);
    2. the remaining keys whose value is not a non-empty dict, sorted;
    3. the remaining keys holding a non-empty dict, sorted. For each such
       key ``(key, PATHCHANGE)`` comes first, then its flattened contents
       with paths prefixed by ``key``.

    ``manually_sorted_keys`` is applied again at every nesting level.
    """
    priority = manually_sorted_keys if manually_sorted_keys is not None else []

    for wanted in priority:
        if wanted in dictionary:
            yield wanted, dictionary[wanted]

    def _has_children(name):
        # Non-empty dicts are expanded; empty dicts count as plain values.
        entry = dictionary[name]
        return bool(entry) and isinstance(entry, dict)

    leaves = sorted(name for name in dictionary.keys()
                    if name not in priority and not _has_children(name))
    for name in leaves:
        yield name, dictionary[name]

    branches = sorted(name for name in dictionary.keys()
                      if name not in priority and _has_children(name))
    for name in branches:
        yield name, PATHCHANGE
        for sub_path, sub_value in iterate_flattened_separately(
                dictionary[name], priority):
            yield join_paths(name, sub_path), sub_value
149
150
151
def iterate_flattened(d):
    """
    Recursively iterate over the items of a dictionary, sorted by key.

    Non-empty nested dicts are descended into, so every yielded pair is
    ``(dotted_path, leaf_value)``; empty dicts are yielded as leaves.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if not (isinstance(value, dict) and value):
            yield key, value
            continue
        for sub_key, sub_value in iterate_flattened(d[key]):
            yield join_paths(key, sub_key), sub_value
164
165
166
def set_by_dotted_path(d, path, value):
    """
    Set an entry in a nested dict using a dotted path.

    Intermediate levels that do not exist yet are created as empty dicts.
    Modifies ``d`` in place and returns None.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    keys = path.split('.')
    parent_keys, last_key = keys[:-1], keys[-1]
    node = d
    for key in parent_keys:
        if key not in node:
            node[key] = dict()
        node = node[key]
    node[last_key] = value
190
191
192
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    An empty path returns ``d`` itself. ``default`` is returned when a path
    segment is missing, or when the path tries to descend into a value that
    is not a mapping (e.g. ``'a.b'`` where ``a`` holds a string or number).

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping
    if not path:
        return d
    current_option = d
    for p in path.split('.'):
        # Only descend into mappings: on a string ``p in value`` would be a
        # substring test, and on a number it raises TypeError.
        if not isinstance(current_option, Mapping) or p not in current_option:
            return default
        current_option = current_option[p]
    return current_option
209
210
211
def iter_path_splits(path):
    """
    Iterate over the ways to split a dotted path into (head, tail).

    The head may be empty; the tail always holds at least the last segment.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    segments = path.split('.')
    for cut in range(len(segments)):
        yield join_paths(*segments[:cut]), join_paths(*segments[cut:])
228
229
230
def iter_prefixes(path):
    """
    Iterate through the prefixes of a dotted path, shortest first.

    The full path itself is the last prefix yielded.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    segments = path.split('.')
    for end in range(1, len(segments) + 1):
        yield join_paths(*segments[:end])
241
242
243
def join_paths(*parts):
    """
    Join parts into one dotted path.

    Falsy parts are skipped; every other part is converted with ``str`` and
    stripped of surrounding dots before joining.
    """
    cleaned = [str(part).strip('.') for part in parts if part]
    return '.'.join(cleaned)
246
247
248
def is_prefix(pre_path, path):
    """
    Return True if ``pre_path`` is a proper dotted-path prefix of ``path``.

    Surrounding dots are ignored. An empty ``pre_path`` is a prefix of
    everything; a path is not a prefix of itself.
    """
    prefix, full = pre_path.strip('.'), path.strip('.')
    if not prefix:
        return True
    return full.startswith(prefix + '.')
253
254
255
def rel_path(base, path):
    """
    Return the dotted ``path`` relative to the dotted ``base``.

    Surrounding dots on either argument are ignored. Returns ``''`` when both
    name the same path, and ``path`` unchanged when ``base`` is empty.

    Raises:
        AssertionError: if ``base`` is not a prefix of ``path``. It is raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """
    stripped_base = base.strip('.')
    stripped_path = path.strip('.')
    if stripped_base == stripped_path:
        return ''
    # Prefix test on whole segments: empty base, or path starts with "base."
    if stripped_base and not stripped_path.startswith(stripped_base + '.'):
        raise AssertionError("{} not a prefix of {}".format(base, path))
    # Slice by the *stripped* base length so stray dots on base cannot
    # shift the cut point.
    return stripped_path[len(stripped_base):].strip('.')
261
262
263
def convert_to_nested_dict(dotted_dict):
    """
    Build a nested dict from a dict whose keys are dotted paths.

    Each leaf reached by iterate_flattened is stored with
    set_by_dotted_path in a fresh dict; the input is left unmodified.
    """
    result = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf)
    return result
269
270
271
def _is_sacred_frame(frame):
    """
    Return True if ``frame`` runs code from a module of the ``sacred`` package.

    The check looks at the top-level package in the frame's ``__name__``
    global. Frames without a ``__name__`` (e.g. code executed with custom
    globals) count as non-Sacred instead of raising KeyError in the middle
    of traceback printing.
    """
    module_name = frame.f_globals.get("__name__") or ""
    return module_name.split('.')[0] == 'sacred'
273
274
275
def print_filtered_stacktrace():
    """
    Print the exception currently being handled to stderr, minus Sacred frames.

    If the innermost traceback frame belongs to a ``sacred`` module, the
    exception is reported as originating within Sacred, and the traceback is
    printed starting from that innermost frame. Otherwise every frame from a
    ``sacred`` module is left out, so only user code is shown.

    Meant to be called from inside an ``except`` block; when no exception is
    being handled it prints nothing.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    if exc_traceback is None:
        # Not handling an exception: there is no traceback to filter.
        return

    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next

    if _is_sacred_frame(current_tb.tb_frame):
        header = ["Exception originated from within Sacred.\n"
                  "Traceback (most recent calls):\n"]
        texts = tb.format_exception(exc_type, exc_value, current_tb)
        # texts[0] is format_exception's own "Traceback ..." line; the custom
        # header above takes its place.
        print(''.join(header + texts[1:]).strip(), file=sys.stderr)
    elif sys.version_info >= (3, 5):
        # TracebackException, walk_tb and StackSummary (used here and by
        # filtered_traceback_format) only exist from Python 3.5 on.
        tb_exception = tb.TracebackException(exc_type, exc_value,
                                             exc_traceback, limit=None)
        for line in filtered_traceback_format(tb_exception):
            print(line, file=sys.stderr, end="")
    else:
        # Fallback for older interpreters: print each non-Sacred traceback
        # entry on its own, then the exception message.
        print("Traceback (most recent calls WITHOUT Sacred internals):",
              file=sys.stderr)
        current_tb = exc_traceback
        while current_tb is not None:
            if not _is_sacred_frame(current_tb.tb_frame):
                tb.print_tb(current_tb, 1)
            current_tb = current_tb.tb_next
        print("\n".join(tb.format_exception_only(exc_type,
                                                 exc_value)).strip(),
              file=sys.stderr)
304
305
306
def filtered_traceback_format(tb_exception, chain=True):
    """
    Yield the formatted traceback of ``tb_exception`` without Sacred frames.

    Args:
        tb_exception: a ``traceback.TracebackException``.
        chain: if True, first yield the (equally filtered) traceback of the
            exception's ``__cause__`` -- or, failing that and unless context
            is suppressed, its ``__context__`` -- followed by the matching
            separator message from the traceback module.

    The yielded strings carry their own line breaks, so print them with
    ``end=""``.
    """
    if chain:
        nested, separator_name = None, None
        if tb_exception.__cause__ is not None:
            nested, separator_name = tb_exception.__cause__, '_cause_message'
        elif (tb_exception.__context__ is not None and
              not tb_exception.__suppress_context__):
            nested = tb_exception.__context__
            separator_name = '_context_message'
        if nested is not None:
            for text in filtered_traceback_format(nested, chain=chain):
                yield text
            # NOTE(review): _cause_message/_context_message are private names
            # of the traceback module -- confirm on supported Python versions.
            yield getattr(tb, separator_name)

    yield 'Traceback (most recent calls WITHOUT Sacred internals):\n'
    # NOTE(review): exc_traceback is not among TracebackException's documented
    # attributes -- confirm it exists on the Python versions in use.
    node = tb_exception.exc_traceback
    while node is not None:
        if not _is_sacred_frame(node.tb_frame):
            # limit=1: summarise only this entry, not the rest of the chain.
            summary = tb.StackSummary.extract(tb.walk_tb(node),
                                              limit=1,
                                              lookup_lines=True,
                                              capture_locals=False)
            for text in summary.format():
                yield text
        node = node.tb_next

    for text in tb_exception.format_exception_only():
        yield text
332
333
334
def is_subdir(path, directory):
    """
    Return True if ``path`` is ``directory`` itself or lies somewhere below it.

    Both paths are resolved (``realpath`` follows symlinks, ``abspath`` makes
    them absolute) before comparing. The comparison works on whole path
    components, so ``/ab`` is not considered to be inside ``/a``.
    """
    def _normalize(p):
        # Strip any trailing separator before appending exactly one, so that
        # a root directory ("/" or "C:\\") does not become "//" or "C:\\\\",
        # which no other path would start with.
        return os.path.abspath(os.path.realpath(p)).rstrip(os.sep) + os.sep

    return _normalize(path).startswith(_normalize(directory))
339
340
341
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """
    Let a decorator be applied either bare or with keyword arguments.

    ``wrapped`` is itself a decorator. Positional arguments mean it was used
    as ``@deco`` directly on a function, so it is simply called. Without
    positional arguments it was used as ``@deco(key=value)``, so the keyword
    arguments are bound now and the resulting decorator is returned.
    """
    if not args:
        return partial(wrapped, **kwargs)
    return wrapped(*args, **kwargs)
349
350
351
def get_inheritors(cls):
    """
    Return the set of all direct and indirect subclasses of ``cls``.

    ``cls`` itself is not included. Each class appears once, even when it is
    reachable through several parents.
    """
    found = set()
    pending = [cls]
    while pending:
        for sub in pending.pop().__subclasses__():
            if sub in found:
                continue
            found.add(sub)
            pending.append(sub)
    return found
362
363
364
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """
    Convert CamelCase to snake_case.

    Acronyms stay together: ``'HTTPResponseCode'`` -> ``'http_response_code'``.
    """
    # Break before a capitalised word that follows any character:
    # "HTTPResponse" -> "HTTP_Response".
    with_word_breaks = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Break between a lowercase letter/digit and the next uppercase letter:
    # "getHTTP" -> "get_HTTP".
    fully_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', with_word_breaks)
    return fully_split.lower()
370
371
372
def apply_backspaces_and_linefeeds(text):
    """
    Interpret backspaces and carriage returns in text like a terminal would.

    The text is processed line by line (split on ``'\\n'``). A ``'\\r'`` moves
    the cursor back to the start of the line, and a ``'\\b'`` moves it one
    position left (never past the start). Any other character overwrites the
    cell under the cursor, or is appended at the end of the line.

    If the very last character of the whole text is a ``'\\r'``, it is kept
    (appended at the end) so the result can be concatenated with the next
    chunk of output.
    """
    lines = text.split('\n')
    last_line_idx = len(lines) - 1
    rendered = []
    for line_idx, line in enumerate(lines):
        on_last_line = line_idx == last_line_idx
        last_char_idx = len(line) - 1
        cells = []
        pos = 0
        for char_idx, char in enumerate(line):
            trailing_cr = (char == '\r' and on_last_line and
                           char_idx == last_char_idx)
            if char == '\r' and not trailing_cr:
                pos = 0
                continue
            if char == '\b':
                pos = max(0, pos - 1)
                continue
            if trailing_cr:
                # Keep the final carriage return at the end of the output.
                pos = len(cells)
            if pos < len(cells):
                cells[pos] = char
            else:
                cells.append(char)
            pos += 1
        rendered.append(''.join(cells))
    return '\n'.join(rendered)
406
407
408
def module_exists(modname):
    """
    Check whether a module can be found, without importing the module itself.

    For a dotted name such as ``a.b`` the lookup machinery does import the
    parent package ``a``. Returns False instead of raising when a parent
    package is missing or the name is relative.
    """
    if modname.startswith('.'):
        return False  # relative names cannot be resolved without a package
    try:
        from importlib.util import find_spec
    except ImportError:  # Python 2 has no importlib.util
        return pkgutil.find_loader(modname) is not None
    try:
        spec = find_spec(modname)
    except (ImportError, AttributeError):
        # A missing parent package (or a parent that is not a package).
        return False
    except ValueError:
        # find_spec raises this when the module is already in sys.modules
        # but has no __spec__ -- so it certainly exists.
        return True
    return spec is not None and spec.loader is not None
411
412
413
def modules_exist(*modnames):
    """
    Return True if every named module can be found (True for no names).

    Stops at the first module that cannot be found.
    """
    for name in modnames:
        if not module_exists(name):
            return False
    return True
415
416
417
def module_is_in_cache(modname):
    """
    Return True if ``modname`` is a key of ``sys.modules``.

    That is, the module was imported before, anywhere in the process.
    """
    import_cache = sys.modules
    return modname in import_cache
420
421
422
def module_is_imported(modname, scope=None):
    """
    Check whether a module named ``modname`` is bound in a namespace.

    Args:
        modname: fully qualified module name, e.g. ``'numpy'``.
        scope: namespace dict to search; defaults to the caller's globals().

    Returns:
        True if ``modname`` is in ``sys.modules`` and some value in ``scope``
        is a module object whose ``__name__`` equals ``modname``. The name
        it is bound to may differ (e.g. ``import numpy as np``).
    """
    # return early if modname is not even cached
    if not module_is_in_cache(modname):
        return False

    if scope is None:  # use globals() of the caller by default
        # currentframe().f_back is the caller's frame; unlike inspect.stack()
        # this does not gather source context for every frame on the stack.
        frame = inspect.currentframe()
        try:
            if frame is not None:
                scope = frame.f_back.f_globals
            else:  # interpreters without stack frame support
                scope = inspect.stack()[1][0].f_globals
        finally:
            # Drop the frame reference to avoid a reference cycle.
            del frame

    for m in scope.values():
        if isinstance(m, type(sys)) and m.__name__ == modname:
            return True

    return False
436
437
438
def ensure_wellformed_argv(argv):
    """
    Normalise a command line into a list (or tuple) of strings.

    Args:
        argv: None (use ``sys.argv`` itself), a single string (split
            shell-style with ``shlex.split``), or a list/tuple of strings
            (returned unchanged).

    Raises:
        ValueError: if argv has another type, or a list/tuple contains
            non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, basestring):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}"
                         .format(type(argv)))
    problems = [a for a in argv if not isinstance(a, basestring)]
    if problems:
        raise ValueError("argv must be list of str but contained the "
                         "following elements: {}".format(problems))
    return argv
452
453
454
class IntervalTimer(threading.Thread):
    """Thread that calls ``func`` every ``interval`` seconds until stopped.

    The thread waits on ``event`` between calls. Once the event is set, the
    loop ends and ``func`` is called one final time before the thread exits.
    """

    @classmethod
    def create(cls, func, interval=10):
        """Return ``(stop_event, timer_thread)``; the thread is not started."""
        stop_event = threading.Event()
        return stop_event, cls(stop_event, func, interval)

    def __init__(self, event, func, interval=10.):
        threading.Thread.__init__(self)
        self.stopped = event
        self.func = func
        self.interval = interval

    def run(self):
        # Event.wait returns False on timeout (keep ticking) and True once set.
        while True:
            if self.stopped.wait(self.interval):
                break
            self.func()
        # One last call after a stop was requested.
        self.func()
471