Completed
Push — master ( cdecc6...9c9404 )
by Klaus
02:35
created

IntervalTimer.create()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 1 Features 0
Metric Value
cc 1
c 1
b 1
f 0
dl 0
loc 5
rs 9.4285
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import inspect
7
import logging
8
import os.path
9
import pkgutil
10
import re
11
import shlex
12
import sys
13
import threading
14
import traceback as tb
15
from functools import partial
16
17
import wrapt
18
19
20
__sacred__ = True  # marks files that should be filtered from stack traces


# Names exported by ``from <this module> import *``.
__all__ = [
    "NO_LOGGER",
    "PYTHON_IDENTIFIER",
    "CircularDependencyError",
    "ObserverError",
    "SacredInterrupt",
    "TimeoutInterrupt",
    "create_basic_stream_logger",
    "recursive_update",
    "iterate_flattened",
    "iterate_flattened_separately",
    "set_by_dotted_path",
    "get_by_dotted_path",
    "iter_path_splits",
    "iter_prefixes",
    "join_paths",
    "is_prefix",
    "convert_to_nested_dict",
    "convert_camel_case_to_snake_case",
    "print_filtered_stacktrace",
    "is_subdir",
    "optional_kwargs_decorator",
    "get_inheritors",
    "apply_backspaces_and_linefeeds",
    "StringIO",
    "FileNotFoundError",
]
32
33
# Python 2/3 compatibility shims. ``basestring``, ``int_types``,
# ``FileNotFoundError`` and ``StringIO`` can be imported from this module
# under either interpreter.
if sys.version_info[0] != 2:
    basestring = str
    int_types = (int,)

    # Re-bind the builtin as a module global so it can be imported from here.
    FileNotFoundError = FileNotFoundError
    from io import StringIO
else:
    basestring = basestring
    int_types = (int, long)

    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in for the Python 3 builtin of the same name."""

        def __init__(self, msg):
            # Carry ENOENT as the errno, like the Python 3 builtin does.
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)

    from StringIO import StringIO
51
52
53
# A disabled logger: records sent to it are dropped. Useful wherever a logger
# object is required but no output is wanted.
NO_LOGGER = logging.getLogger('ignore')
NO_LOGGER.disabled = 1

# Sentinel yielded by iterate_flattened_separately() before it descends into
# a nested dictionary.
PATHCHANGE = object()

# Matches a (plain ASCII) Python identifier. ``\Z`` is used instead of ``$``
# because ``$`` also matches right before a trailing newline, which would make
# e.g. "foo\n" pass as an identifier.
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")
59
60
61
class CircularDependencyError(Exception):
    """
    Error indicating a dependency cycle.

    The ingredients of the current experiment form a circular dependency.
    """
63
64
65
class ObserverError(Exception):
    """
    Error raised by an observer.

    Unlike other exceptions, it should not make the run fail.
    """
67
68
69
class SacredInterrupt(Exception):
    """Base-Class for all custom interrupts.

    ``STATUS`` names the status given to a run interrupted by this
    exception; subclasses such as ``TimeoutInterrupt`` override it.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "INTERRUPTED"


class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    This exception can be used in client code to indicate that the run
    exceeded its time limit and has been interrupted because of that.
    The status of the interrupted run will then be set to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
89
90
91
def create_basic_stream_logger():
    """
    Reset the root logger to a single INFO-level stream handler.

    All handlers currently attached to the root logger are dropped and
    replaced by one ``logging.StreamHandler`` whose records are formatted as
    ``LEVELNAME - logger name - message``.

    Returns
    -------
    logging.Logger
        The reconfigured root logger.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))

    root = logging.getLogger('')
    root.setLevel(logging.INFO)
    root.handlers = []
    root.addHandler(handler)
    return root
100
101
102
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mapping values in ``u`` are merged into the matching entry of ``d``. A
    new dict is created when ``d`` has no such entry. All other values
    overwrite the entry in ``d``. ``d`` is modified in place and returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    # ``collections.Mapping`` was removed in Python 3.10. The ABC lives in
    # ``collections.abc`` (Python 3.3+); fall back for Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d
118
119
120
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    First iterate over manually sorted keys and then over all items that are
    non-dictionary values (sorted by keys), then over the rest
    (sorted by keys), providing full dotted paths for every leaf.

    Empty dicts count as leaves. Before descending into a non-empty nested
    dict, ``(key, PATHCHANGE)`` is yielded. ``manually_sorted_keys`` is
    applied again at every nesting level.
    """
    if manually_sorted_keys is None:
        manually_sorted_keys = []
    for key in manually_sorted_keys:
        if key in dictionary:
            yield key, dictionary[key]

    def is_branch(value):
        # Check the type first. Calling bool() on arbitrary leaf values can
        # raise (e.g. numpy arrays), and only dicts need the emptiness test.
        return isinstance(value, dict) and bool(value)

    single_line_keys = [key for key in dictionary.keys()
                        if key not in manually_sorted_keys and
                        not is_branch(dictionary[key])]
    for key in sorted(single_line_keys):
        yield key, dictionary[key]

    multi_line_keys = [key for key in dictionary.keys()
                       if key not in manually_sorted_keys and
                       is_branch(dictionary[key])]
    for key in sorted(multi_line_keys):
        yield key, PATHCHANGE
        for k, val in iterate_flattened_separately(dictionary[key],
                                                   manually_sorted_keys):
            yield join_paths(key, k), val
150
151
152
def iterate_flattened(d):
    """
    Recursively iterate over the items of a dictionary.

    Yields ``(dotted_path, value)`` for every leaf, visiting keys in sorted
    order at each level. Nested dicts are descended into rather than yielded
    themselves, so an empty nested dict contributes nothing.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if not isinstance(value, dict):
            yield key, value
            continue
        for sub_path, leaf in iterate_flattened(value):
            yield join_paths(key, sub_path), leaf
165
166
167
def set_by_dotted_path(d, path, value):
    """
    Set an entry in a nested dict using a dotted path.

    Intermediate dictionaries that do not exist yet are created. ``d`` is
    modified in place; nothing is returned.

    Examples
    --------
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}

    """
    keys = path.split('.')
    leaf_key = keys.pop()
    node = d
    for key in keys:
        if key not in node:
            node[key] = dict()  # create the missing intermediate level
        node = node[key]
    node[leaf_key] = value
191
192
193
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    Returns ``default`` if any component of the path is missing. This
    includes the case where the path continues past a value that cannot be
    looked up by key, e.g. an int or string leaf. An empty path returns
    ``d`` itself.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d
    current_option = d
    for p in path.split('.'):
        try:
            if p not in current_option:
                return default
            current_option = current_option[p]
        except TypeError:
            # The path runs past a leaf that is not a container
            # (``'x' in 5``) or is not keyed by strings (``'abc'['a']``).
            # Treat the entry as missing instead of crashing.
            return default
    return current_option
210
211
212
def iter_path_splits(path):
    """
    Iterate over possible splits of a dotted path.

    Each split is a ``(head, tail)`` pair. The head may be empty; the tail
    holds at least the last path component.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    parts = path.split('.')
    for cut in range(len(parts)):
        yield join_paths(*parts[:cut]), join_paths(*parts[cut:])
229
230
231
def iter_prefixes(path):
    """
    Iterate through all (non-empty) prefixes of a dotted path.

    Prefixes are produced from shortest to longest, ending with the full
    path.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    prefix_parts = []
    for part in path.split('.'):
        prefix_parts.append(part)
        yield join_paths(*prefix_parts)
242
243
244
def join_paths(*parts):
    """
    Join different parts together to a valid dotted path.

    Falsy parts (``None``, ``''``, ``0``, ...) are skipped. Every other part
    is converted with ``str`` and stripped of leading and trailing dots
    before joining.
    """
    cleaned = [str(part).strip('.') for part in parts if part]
    return '.'.join(cleaned)
247
248
249
def is_prefix(pre_path, path):
    """
    Return True if pre_path is a path-prefix of path.

    Leading and trailing dots are ignored, and whole components are
    compared. ``'foo'`` is a prefix of ``'foo.bar'`` but not of ``'foobar'``
    or of ``'foo'`` itself. An empty prefix matches every path.
    """
    prefix = pre_path.strip('.')
    full = path.strip('.')
    if not prefix:
        return True
    return full.startswith(prefix + '.')
254
255
256
def convert_to_nested_dict(dotted_dict):
    """
    Convert a dict with dotted path keys to corresponding nested dict.

    E.g. ``{'a.b': 1, 'a.c': 2}`` becomes ``{'a': {'b': 1, 'c': 2}}``.
    Nested dicts in the input are flattened first via iterate_flattened, so
    they are combined with the dotted keys into a single nested result.
    """
    result = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf)
    return result
262
263
264
def print_filtered_stacktrace():
    """
    Print the exception currently being handled to stderr.

    If the innermost traceback frame belongs to a module whose globals
    contain ``__sacred__``, the full traceback is printed, because the
    exception originated inside Sacred. Otherwise only frames from modules
    without that marker are printed, followed by the exception message.

    Does nothing when called while no exception is being handled.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    if exc_traceback is None:
        # Not inside an ``except`` block: there is nothing to report.
        return
    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next
    if '__sacred__' in current_tb.tb_frame.f_globals:
        print("Exception originated from within Sacred.\n"
              "Traceback (most recent calls):", file=sys.stderr)
        tb.print_tb(exc_traceback)
        tb.print_exception(exc_type, exc_value, None)
    else:
        print("Traceback (most recent calls WITHOUT Sacred internals):",
              file=sys.stderr)
        current_tb = exc_traceback
        while current_tb is not None:
            if '__sacred__' not in current_tb.tb_frame.f_globals:
                tb.print_tb(current_tb, 1)
            current_tb = current_tb.tb_next
        # format_exception_only() returns lines that already end with '\n'.
        # Joining them with '\n' would put blank lines into multi-line
        # messages such as SyntaxError.
        print("".join(tb.format_exception_only(exc_type, exc_value)).strip(),
              file=sys.stderr)
285
286
287
def is_subdir(path, directory):
    """
    Return True if ``path`` is ``directory`` itself or lies below it.

    Both paths are resolved with ``os.path.realpath``, which follows symlinks
    and makes them absolute, before they are compared. The comparison works
    on whole path components, so ``/foo/barbaz`` is not inside ``/foo/bar``.
    """
    # os.path.join(p, '') appends a separator only when ``p`` does not already
    # end in one. Blindly adding os.sep turned the filesystem root into '//'
    # (or 'C:\\\\'), so no path was ever reported to be inside it.
    path = os.path.join(os.path.realpath(path), '')
    directory = os.path.join(os.path.realpath(directory), '')
    return path.startswith(directory)
292
293
294
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """
    Let a decorator be used either bare or with keyword arguments.

    ``wrapped`` is itself a decorator. Used bare (``@deco``), it receives the
    decorated object positionally and is called directly. Used with keyword
    arguments only (``@deco(key=value)``), a partial of ``wrapped`` with those
    keywords bound is returned and then applied to the decorated object.
    """
    if not args:
        # Keyword arguments only: return a decorator with them bound.
        return partial(wrapped, **kwargs)
    # Positional arguments present: behave like a plain decorator call.
    return wrapped(*args, **kwargs)
302
303
304
def get_inheritors(cls):
    """
    Get a set of all classes that inherit from the given class.

    ``__subclasses__()`` is followed transitively, so indirect subclasses are
    included as well. ``cls`` itself is not part of the result.
    """
    found = set()
    pending = [cls]
    while pending:
        for child in pending.pop().__subclasses__():
            if child in found:
                continue
            found.add(child)
            pending.append(child)
    return found
315
316
317
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """
    Convert CamelCase to snake_case.

    >>> convert_camel_case_to_snake_case('getHTTPResponseCode')
    'get_http_response_code'
    """
    # Step 1 inserts '_' before each capitalised word ("Response" in
    # "HTTPResponse"). Step 2 inserts '_' between a lowercase letter or
    # digit and a following capital.
    words_split = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    snake = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return snake.lower()
323
324
325
def apply_backspaces_and_linefeeds(text):
    """
    Interpret backspaces and linefeeds in text like a terminal would.

    The text is processed line by line (split on '\\n'). A '\\b' moves the
    cursor one position left, never past the start of the line. A '\\r'
    moves it back to the start of the line, and later characters overwrite
    what is already there.

    If the text ends with a carriage return, that final '\\r' is kept, so
    the result can still be concatenated with the next output chunk.
    """
    lines = text.split('\n')
    last_line_no = len(lines) - 1
    result = []
    for line_no, line in enumerate(lines):
        buf = []
        pos = 0
        last_char_no = len(line) - 1
        for char_no, ch in enumerate(line):
            # The very last character of the whole text, if it is '\r'.
            is_final_cr = (ch == '\r' and line_no == last_line_no and
                           char_no == last_char_no)
            if ch == '\r' and not is_final_cr:
                pos = 0
                continue
            if ch == '\b':
                pos = max(0, pos - 1)
                continue
            if is_final_cr:
                # Append the trailing '\r' instead of overwriting.
                pos = len(buf)
            if pos == len(buf):
                buf.append(ch)
            else:
                buf[pos] = ch
            pos += 1
        result.append(''.join(buf))
    return '\n'.join(result)
359
360
361
def module_exists(modname):
    """
    Checks if a module exists without actually importing it.

    Returns True if a module named ``modname`` can be located and False
    otherwise, including when a parent package of a dotted name is missing.
    Locating ``a.b`` does import the parent package ``a``.
    """
    try:
        from importlib.util import find_spec
    except ImportError:  # Python 2 (and Python < 3.4)
        find_spec = None
    try:
        if find_spec is None:
            return pkgutil.find_loader(modname) is not None
        # pkgutil.find_loader() is deprecated since Python 3.12 and removed
        # in 3.14; it only delegated to find_spec() anyway.
        return find_spec(modname) is not None
    except ImportError:
        # e.g. the parent package of a dotted name does not exist
        return False
    except ValueError:
        # find_spec() raises ValueError for a module that is already in
        # sys.modules but has no __spec__.
        return modname in sys.modules
364
365
366
def modules_exist(*modnames):
    """Return True if every named module exists (also True for no names)."""
    for name in modnames:
        if not module_exists(name):
            return False
    return True
368
369
370
def module_is_in_cache(modname):
    """Return True if ``modname`` has been imported before (is in sys.modules)."""
    loaded = sys.modules  # import cache, keyed by module name
    return modname in loaded
373
374
375
def module_is_imported(modname, scope=None):
    """
    Checks if a module is imported within the current namespace.

    Searches the values of ``scope`` (by default the caller's global
    namespace) for a module object whose ``__name__`` equals ``modname``.
    Aliased imports such as ``import numpy as np`` are found as well.

    Returns False immediately if the module has never been imported at all.
    """
    # return early if modname is not even cached
    if not module_is_in_cache(modname):
        return False

    if scope is None:  # use globals() of the caller by default
        # sys._getframe(1) is the caller's frame, the same frame that
        # inspect.stack()[1][0] gives, but without building FrameInfo records
        # (including source context) for the entire stack.
        scope = sys._getframe(1).f_globals

    return any(isinstance(m, type(sys)) and m.__name__ == modname
               for m in scope.values())
389
390
391
def ensure_wellformed_argv(argv):
    """
    Normalise ``argv`` into a sequence of strings.

    ``None`` returns ``sys.argv`` itself. A single string is split
    shell-style with ``shlex.split``. A list or tuple is returned unchanged
    once all of its elements have been checked to be strings.

    Raises
    ------
    ValueError
        If ``argv`` is not None, a string, a list or a tuple, or if it
        contains non-string elements.
    """
    if argv is None:
        return sys.argv
    if isinstance(argv, basestring):
        return shlex.split(argv)
    if not isinstance(argv, (list, tuple)):
        raise ValueError("argv must be str or list, but was {}"
                         .format(type(argv)))
    problems = [a for a in argv if not isinstance(a, basestring)]
    if problems:
        raise ValueError("argv must be list of str but contained the "
                         "following elements: {}".format(problems))
    return argv
405
406
407
class IntervalTimer(threading.Thread):
    """
    Thread that calls ``func`` every ``interval`` seconds until stopped.

    The loop ends once the ``threading.Event`` passed as ``event`` is set.
    ``func`` is then called one final time before the thread finishes.
    """

    @classmethod
    def create(cls, func, interval=10):
        """
        Build a timer together with its stop event.

        Returns ``(stop_event, timer_thread)``. The thread is not started;
        call ``start()`` on it, and ``set()`` the event to stop it.
        """
        stop_event = threading.Event()
        return stop_event, cls(stop_event, func, interval)

    def __init__(self, event, func, interval=10.):
        super(IntervalTimer, self).__init__()
        self.stopped = event  # setting this event ends the loop in run()
        self.func = func
        self.interval = interval  # seconds to wait between calls of func

    def run(self):
        # Event.wait() returns True as soon as the event is set and False
        # when the timeout expires, so func runs once per elapsed interval.
        while True:
            if self.stopped.wait(self.interval):
                break
            self.func()
        # One last call after the stop event has been set.
        self.func()
424