Completed
Push — master ( dbc3e8...c58680 )
by Klaus
37s
created

module_is_imported()   B

Complexity

Conditions 6

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 14
rs 8
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import inspect
7
import logging
8
import os.path
9
import pkgutil
10
import re
11
import sys
12
import traceback as tb
13
from functools import partial
14
15
import wrapt
16
17
18
__sacred__ = True  # marks files that should be filtered from stack traces

# Public names of this module, i.e. what ``import *`` exports.
__all__ = ["NO_LOGGER", "PYTHON_IDENTIFIER", "CircularDependencyError",
           "ObserverError", "SacredInterrupt", "TimeoutInterrupt",
           "create_basic_stream_logger", "recursive_update",
           "iterate_flattened", "iterate_flattened_separately",
           "set_by_dotted_path", "get_by_dotted_path", "iter_path_splits",
           "iter_prefixes", "join_paths", "is_prefix",
           "convert_to_nested_dict", "convert_camel_case_to_snake_case",
           "print_filtered_stacktrace", "is_subdir",
           "optional_kwargs_decorator", "get_inheritors",
           "apply_backspaces_and_linefeeds", "StringIO", "FileNotFoundError"]
30
31
# A PY2 compatible FileNotFoundError
32
if sys.version_info[0] == 2:
    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in for the Python 3 builtin of the same name."""

        def __init__(self, msg):
            # Always report ENOENT as the errno, with msg as the description.
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)

    from StringIO import StringIO
else:
    # Re-bind the builtin as a module attribute so that it can be imported
    # from this module on Python 3 as well.
    FileNotFoundError = FileNotFoundError
    from io import StringIO
43
44
# A logger named 'ignore' that is switched off.
NO_LOGGER = logging.getLogger('ignore')
NO_LOGGER.disabled = True

# Sentinel that iterate_flattened_separately yields in place of a value when
# it starts descending into a nested dictionary.
PATHCHANGE = object()

# Matches names built from ASCII letters, digits and underscores that do not
# start with a digit.
PYTHON_IDENTIFIER = re.compile("^[a-zA-Z_][_a-zA-Z0-9]*$")
50
51
52
class CircularDependencyError(Exception):
    """The ingredients of the current experiment form a circular dependency."""
54
55
56
class ObserverError(Exception):
    """Error that an observer raises but that should not make the run fail."""
58
59
60
class SacredInterrupt(Exception):
    """Base-Class for all custom interrupts.

    For more information see :ref:`custom_interrupts`.
    """

    # Status string associated with this kind of interrupt; subclasses
    # override it with their own value.
    STATUS = "INTERRUPTED"
67
68
69
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    This exception can be used in client code to indicate that the run
    exceeded its time limit and has been interrupted because of that.
    The status of the interrupted run will then be set to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
80
81
82
def create_basic_stream_logger():
    """
    Configure the root logger with a single stream handler and return it.

    The root logger's level is set to INFO and all handlers that were
    attached to it before are dropped. The new handler formats records as
    ``<levelname> - <logger name> - <message>``.
    """
    logger = logging.getLogger('')
    logger.setLevel(logging.INFO)
    logger.handlers = []  # discard whatever handlers were installed before
    ch = logging.StreamHandler()
    formatter = logging.Formatter('%(levelname)s - %(name)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger
91
92
93
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Values of u that are mappings are merged into the entry of d with the
    same key (an empty dict is used if d has no such key); all other values
    simply replace the entry in d. The update happens in place and d is
    returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    # The ABC aliases in ``collections`` were removed in Python 3.10, so look
    # in ``collections.abc`` first and only fall back for Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:
        from collections import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d
109
110
111
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    First iterate over manually sorted keys and then over all items that are
    non-dictionary values (sorted by keys), then over the rest
    (sorted by keys), providing full dotted paths for every leaf.

    Entries for the manually sorted keys are yielded as they are, i.e. their
    values are not flattened. Empty dictionaries count as non-dictionary
    values. Before the entries of a nested dictionary, its key is yielded
    together with the ``PATHCHANGE`` sentinel. The manually sorted keys are
    passed on to the recursive calls unchanged.
    """
    if manually_sorted_keys is None:
        manually_sorted_keys = []

    def is_nested(value):
        # Check the type first: truthiness is then only evaluated for dicts,
        # so values whose truth value is ambiguous (e.g. arrays) are fine.
        return isinstance(value, dict) and bool(value)

    for key in manually_sorted_keys:
        if key in dictionary:
            yield key, dictionary[key]

    single_line_keys = [key for key in dictionary.keys()
                        if key not in manually_sorted_keys and
                        not is_nested(dictionary[key])]
    for key in sorted(single_line_keys):
        yield key, dictionary[key]

    multi_line_keys = [key for key in dictionary.keys()
                       if key not in manually_sorted_keys and
                       is_nested(dictionary[key])]
    for key in sorted(multi_line_keys):
        yield key, PATHCHANGE
        for k, val in iterate_flattened_separately(dictionary[key],
                                                   manually_sorted_keys):
            yield join_paths(key, k), val
140
            yield join_paths(key, k), val
141
142
143
def iterate_flattened(d):
    """
    Recursively iterate over the items of a dictionary.

    Keys are visited in sorted order. Provides a full dotted path for every
    leaf, i.e. for every value that is not a dict. Empty nested dicts have no
    leaves and therefore produce no entries.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if isinstance(value, dict):
            for k, v in iterate_flattened(value):
                yield join_paths(key, k), v
        else:
            yield key, value
156
157
158
def set_by_dotted_path(d, path, value):
    """
    Set an entry in a nested dict using a dotted path.

    Will create dictionaries as needed. The dict d is modified in place and
    nothing is returned.

    Examples:
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}
    """
    split_path = path.split('.')
    current_option = d
    # walk down to the dict that holds the final key, creating missing levels
    for p in split_path[:-1]:
        if p not in current_option:
            current_option[p] = dict()
        current_option = current_option[p]
    current_option[split_path[-1]] = value
180
181
182
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    An empty path returns d itself. If any part of the path is missing,
    default is returned instead.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d
    split_path = path.split('.')
    current_option = d
    for p in split_path:
        if p not in current_option:
            return default
        current_option = current_option[p]
    return current_option
199
200
201
def iter_path_splits(path):
    """
    Iterate over possible splits of a dotted path.

    Every split is a pair of dotted paths. The first part can be empty, the
    second should not be.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    split_path = path.split('.')
    for i in range(len(split_path)):
        p1 = join_paths(*split_path[:i])
        p2 = join_paths(*split_path[i:])
        yield p1, p2
218
219
220
def iter_prefixes(path):
    """
    Iterate through all (non-empty) prefixes of a dotted path.

    The prefixes are produced from shortest to longest; the last one is the
    full path.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    split_path = path.split('.')
    for i in range(1, len(split_path) + 1):
        yield join_paths(*split_path[:i])
231
232
233
def join_paths(*parts):
    """
    Join different parts together to a valid dotted path.

    Falsy parts (such as empty strings or None) are skipped. Every other part
    is converted to a string and stripped of leading and trailing dots before
    the parts are joined with single dots.

    Example:
    >>> join_paths('foo.', '', '.bar')
    'foo.bar'
    """
    return '.'.join(str(p).strip('.') for p in parts if p)
236
237
238
def is_prefix(pre_path, path):
    """
    Return True if pre_path is a path-prefix of path.

    Leading and trailing dots of both arguments are ignored. The empty path
    is a prefix of every path. Apart from that only proper prefixes that end
    at a dot count, so a path is not considered a prefix of itself.

    Example:
    >>> is_prefix('foo', 'foo.bar')
    True
    >>> is_prefix('foo', 'foobar')
    False
    """
    pre_path = pre_path.strip('.')
    path = path.strip('.')
    return not pre_path or path.startswith(pre_path + '.')
243
244
245
def convert_to_nested_dict(dotted_dict):
    """
    Convert a dict with dotted path keys to corresponding nested dict.

    The input is left untouched; a new dict is built by setting every leaf
    of dotted_dict at its dotted path.
    """
    nested_dict = {}
    for k, v in iterate_flattened(dotted_dict):
        set_by_dotted_path(nested_dict, k, v)
    return nested_dict
251
252
253
def print_filtered_stacktrace():
    """
    Print the traceback of the exception currently being handled to stderr.

    If the innermost frame of the traceback belongs to a module that defines
    ``__sacred__`` in its globals, the complete traceback is printed.
    Otherwise only the frames of modules without that marker are printed,
    followed by the exception itself.

    This relies on ``sys.exc_info()`` and therefore has to be called while an
    exception is being handled (i.e. from within an ``except`` block).
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()
    # determine if last exception is from sacred
    current_tb = exc_traceback
    while current_tb.tb_next is not None:
        current_tb = current_tb.tb_next
    if '__sacred__' in current_tb.tb_frame.f_globals:
        print("Exception originated from within Sacred.\n"
              "Traceback (most recent calls):", file=sys.stderr)
        tb.print_tb(exc_traceback)
        tb.print_exception(exc_type, exc_value, None)
    else:
        print("Traceback (most recent calls WITHOUT Sacred internals):",
              file=sys.stderr)
        current_tb = exc_traceback
        while current_tb is not None:
            if '__sacred__' not in current_tb.tb_frame.f_globals:
                # limit of 1: print just this entry, not the ones after it
                tb.print_tb(current_tb, 1)
            current_tb = current_tb.tb_next
        print("\n".join(tb.format_exception_only(exc_type, exc_value)).strip(),
              file=sys.stderr)
274
275
276
def is_subdir(path, directory):
    """
    Return True if path is located inside of directory (or is the same).

    Both arguments are made absolute with symbolic links resolved before
    they are compared.
    """
    # The trailing separator keeps '/a/bc' from counting as inside of '/a/b'.
    path = os.path.abspath(os.path.realpath(path)) + os.sep
    directory = os.path.abspath(os.path.realpath(directory)) + os.sep

    return path.startswith(directory)
281
282
283
# Lets a decorator be applied both bare (``@deco``) and with keyword
# arguments only (``@deco(option=value)``).
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    # here wrapped is itself a decorator
    if args:  # means it was used as a normal decorator (so just call it)
        return wrapped(*args, **kwargs)
    else:  # used with kwargs, so we need to return a decorator
        # the keyword arguments are stored until the decorated object arrives
        return partial(wrapped, **kwargs)
291
292
293
def get_inheritors(cls):
    """
    Get a set of all classes that inherit from the given class.

    Direct and indirect subclasses are included, cls itself is not.
    """
    subclasses = set()
    work = [cls]  # classes whose direct subclasses still need to be visited
    while work:
        parent = work.pop()
        for child in parent.__subclasses__():
            if child not in subclasses:
                subclasses.add(child)
                work.append(child)
    return subclasses
304
305
306
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """
    Convert CamelCase to snake_case.

    Runs of capital letters are kept together as one word.

    Example:
    >>> convert_camel_case_to_snake_case('getHTTPResponse')
    'get_http_response'
    """
    # first separate every capitalized word from whatever precedes it ...
    s1 = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    # ... then split the remaining lowercase/digit -> uppercase transitions
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
312
313
314
def apply_backspaces_and_linefeeds(text):
    """
    Interpret backspaces and linefeeds in text like a terminal would.

    Interpret text like a terminal by removing backspace and carriage return
    characters and applying them line by line: a carriage return moves the
    cursor back to the start of the line, a backspace moves it one character
    to the left, and every other character overwrites whatever is at the
    cursor position.

    If the final line ends with a carriage return, it is kept so that the
    result can be concatenated with the next output chunk.
    """
    orig_lines = text.split('\n')
    last_line_idx = len(orig_lines) - 1
    new_lines = []
    for line_idx, orig_line in enumerate(orig_lines):
        chars, cursor = [], 0
        last_char_idx = len(orig_line) - 1
        for char_idx, char in enumerate(orig_line):
            # a carriage return that is the very last character of the text
            trailing_cr = (char == '\r' and
                           char_idx == last_char_idx and
                           line_idx == last_line_idx)
            if char == '\r' and not trailing_cr:
                cursor = 0
            elif char == '\b':
                cursor = max(0, cursor - 1)
            else:
                if trailing_cr:
                    # keep it verbatim at the end of the line
                    cursor = len(chars)
                if cursor == len(chars):
                    chars.append(char)
                else:
                    chars[cursor] = char
                cursor += 1
        new_lines.append(''.join(chars))
    return '\n'.join(new_lines)
348
349
350
def module_exists(modname):
    """
    Checks if a module exists without actually importing it.

    Note that for a dotted name the import system has to import the parent
    package in order to locate the submodule. If that lookup fails (e.g.
    because the parent package does not exist) the module is reported as
    missing.
    """
    try:
        # preferred API; pkgutil.find_loader is deprecated on Python 3
        from importlib.util import find_spec as find_module
    except ImportError:  # Python 2 has no importlib.util
        find_module = pkgutil.find_loader

    try:
        return find_module(modname) is not None
    except (ImportError, AttributeError, TypeError, ValueError):
        # Same set of lookup failures that pkgutil.find_loader reports as an
        # ImportError; none of them yields a usable module.
        return False
353
354
355
def modules_exist(*modnames):
    """Checks if all of the given modules exist (see module_exists)."""
    return all(module_exists(m) for m in modnames)
357
358
359
def module_is_in_cache(modname):
    """Checks if a module was imported before (is in the import cache)."""
    return modname in sys.modules
362
363
364
def module_is_imported(modname, scope=None):
    """
    Checks if a module is imported within the current namespace.

    A module counts as imported if one of the values in scope is a module
    object whose ``__name__`` equals modname; the name it is bound to does
    not matter. By default scope is the global namespace of the caller.
    """
    # return early if modname is not even cached
    if not module_is_in_cache(modname):
        return False

    if scope is None:  # use globals() of the caller by default
        # f_back of the current frame is the caller's frame; this avoids
        # building records for the entire stack as inspect.stack() would.
        scope = inspect.currentframe().f_back.f_globals

    module_type = type(sys)
    return any(isinstance(m, module_type) and m.__name__ == modname
               for m in scope.values())
378