Completed
Push — master ( dbc38f...56accc )
by Klaus
01:34
created

iterate_flattened_separately()   F

Complexity

Conditions 15

Size

Total Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 15
dl 0
loc 30
rs 2.7451
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex functions like iterate_flattened_separately() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for statements and local variables that work on the same data, or that share the same prefixes or suffixes.

Once you have determined the statements that belong together, you can apply the Extract Method refactoring. If the component makes sense as a separate unit with its own state, Extract Class is also a candidate.

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import logging
7
import os.path
8
import re
9
import subprocess
10
import sys
11
import traceback as tb
12
from functools import partial
13
from contextlib import contextmanager
14
15
import wrapt
16
17
from sacred.optional import libc
18
19
# Module-level marker: print_filtered_stacktrace() checks a frame's globals
# for the name '__sacred__' to decide whether the frame is Sacred-internal.
__sacred__ = True  # marks files that should be filtered from stack traces
20
21
# Placeholder logger registered under the name 'ignore'.
NO_LOGGER = logging.getLogger('ignore')
22
# Mark the placeholder logger as disabled so it does not process records.
NO_LOGGER.disabled = 1
23
24
# Unique sentinel: iterate_flattened_separately() yields it as the "value" of
# a key right before descending into that key's nested dictionary.
PATHCHANGE = object()
25
26
# Matches a complete (ASCII) Python identifier.
# ``\Z`` is used instead of ``$`` because ``$`` also matches just before a
# trailing newline, which would accept strings such as "foo\n".
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")
27
28
# A PY2 compatible FileNotFoundError
if sys.version_info[0] != 2:
    # Python 3 already has the builtin; rebind it at module level so that it
    # can be imported from this module on either major version.
    FileNotFoundError = FileNotFoundError
else:
    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in: an IOError carrying errno.ENOENT."""

        def __init__(self, msg):
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)
38
39
40
def flush():
    """
    Flush the Python-level and the C-level stdio buffers, best effort.

    Errors of type AttributeError, ValueError and IOError are ignored in
    both steps, so streams that do not support flushing are skipped silently.
    """
    ignorable = (AttributeError, ValueError, IOError)

    try:
        # Look each stream up only when it is about to be flushed.
        for stream_name in ('stdout', 'stderr'):
            getattr(sys, stream_name).flush()
    except ignorable:
        pass  # unsupported

    try:
        # fflush(NULL) asks the C library to flush all of its open streams.
        libc.fflush(None)
    except ignorable:
        pass  # unsupported
51
52
53
class CircularDependencyError(Exception):
    """Raised when the ingredients of an experiment depend on each other
    in a cycle."""
55
56
57
class ObserverError(Exception):
    """Raised by an observer for failures that should not fail the run."""
59
60
61
class SacredInterrupt(Exception):
    """Common base class of all custom interrupts.

    Subclasses set ``STATUS`` to a status string of their own.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "INTERRUPTED"
68
69
70
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    Client code can raise this exception to indicate that the run exceeded
    its time limit and was interrupted for that reason. The status of the
    interrupted run will then be set to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
81
82
83
def create_basic_stream_logger():
    """
    Reset the root logger to one stream handler at INFO level.

    All handlers currently attached to the root logger are dropped and
    replaced by a single ``StreamHandler`` that formats records as
    ``LEVEL - logger name - message``.

    Returns the root logger.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))

    root = logging.getLogger('')
    root.setLevel(logging.INFO)
    root.handlers = []  # discard whatever was attached before
    root.addHandler(handler)
    return root
92
93
94
try:  # Python 3.3+; the ``collections.Mapping`` alias was removed in 3.10
    from collections.abc import Mapping as _Mapping
except ImportError:  # Python 2
    from collections import Mapping as _Mapping


def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Mapping values in ``u`` are merged into the corresponding mapping in
    ``d``; every other value in ``u`` overwrites the entry in ``d``. If ``u``
    holds a mapping for a key whose current value in ``d`` is missing or is
    not a mapping, that value is replaced by a new dict built from ``u``.

    ``d`` is modified in place and also returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    for k, v in u.items():
        if isinstance(v, _Mapping):
            existing = d.get(k)
            if not isinstance(existing, _Mapping):
                # nothing to merge into: start from an empty dict instead of
                # failing on a scalar / None that is about to be overwritten
                existing = {}
            d[k] = recursive_update(existing, v)
        else:
            d[k] = v
    return d
110
111
112
# Duplicate stdout and stderr to a file. Inspired by:
113
# http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
114
# http://stackoverflow.com/a/651718/1388435
115
# http://stackoverflow.com/a/22434262/1388435
116
@contextmanager
def tee_output(target):
    """
    Duplicate everything written to stdout/stderr (fds 1 and 2) into target.

    While the context is active, file descriptors 1 and 2 are redirected
    into two ``tee`` child processes that forward the data both to the
    original stream and to ``target``. If ``tee`` cannot be started, the
    ``sacred.pytee`` module is run with the current interpreter instead.

    The original descriptors are restored and the child processes are shut
    down even if the body of the ``with`` block raises.

    :param target: open file object; it must provide ``fileno()``,
                   ``flush()``, ``seek()`` and a ``read()`` that returns
                   bytes.
    :return: (yielded) a list that is empty while the context is active;
             on exit the decoded content of ``target`` is appended to it.
    """
    original_stdout_fd = 1
    original_stderr_fd = 2

    # Save a copy of the original stdout and stderr file descriptors
    saved_stdout_fd = os.dup(original_stdout_fd)
    saved_stderr_fd = os.dup(original_stderr_fd)

    target_fd = target.fileno()

    final_output = []

    try:
        try:
            # tee copies its stdin to its stdout (the real stream) and, via
            # '/dev/stderr', to its own stderr, which is wired to target.
            tee_stdout = subprocess.Popen(
                ['tee', '-a', '/dev/stderr'],
                stdin=subprocess.PIPE, stderr=target_fd, stdout=1)
            tee_stderr = subprocess.Popen(
                ['tee', '-a', '/dev/stderr'],
                stdin=subprocess.PIPE, stderr=target_fd, stdout=2)
        except (FileNotFoundError, OSError):
            # no usable 'tee' executable: fall back to the bundled module
            tee_stdout = subprocess.Popen(
                [sys.executable, "-m", "sacred.pytee"],
                stdin=subprocess.PIPE, stderr=target_fd)
            tee_stderr = subprocess.Popen(
                [sys.executable, "-m", "sacred.pytee"],
                stdin=subprocess.PIPE, stdout=target_fd)

        # make sure nothing buffered earlier ends up in the capture
        flush()
        os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd)
        os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd)

        try:
            yield final_output  # let the caller do their printing
        finally:
            # Runs on success AND when the caller's block raised, so that
            # fds 1/2 never stay attached to the tee pipes.
            flush()

            # close our handles on the pipes ...
            tee_stdout.stdin.close()
            tee_stderr.stdin.close()

            # ... and restore the original fds, which drops the last
            # references to the pipes so the children see EOF
            os.dup2(saved_stdout_fd, original_stdout_fd)
            os.dup2(saved_stderr_fd, original_stderr_fd)

            tee_stdout.wait()
            tee_stderr.wait()
    finally:
        os.close(saved_stdout_fd)
        os.close(saved_stderr_fd)
        target.flush()
        target.seek(0)
        final_output.append(target.read().decode())
168
169
170
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    The order is:

    1. the keys listed in ``manually_sorted_keys`` (in that order, if
       present in ``dictionary``), yielded with their value as-is;
    2. all remaining keys whose value is not a non-empty dict, sorted by
       key;
    3. all remaining keys whose value is a non-empty dict, sorted by key.
       For each of those, ``(key, PATHCHANGE)`` is yielded first, followed
       by the recursively flattened entries with full dotted paths.

    :param dictionary: the (possibly nested) dict to iterate over.
    :param manually_sorted_keys: optional list of keys to yield first; it is
                                 also passed on to the recursive calls.
    """
    if manually_sorted_keys is None:
        manually_sorted_keys = []
    for key in manually_sorted_keys:
        if key in dictionary:
            yield key, dictionary[key]

    # Partition the remaining keys in a single pass.
    single_line_keys, multi_line_keys = [], []
    for key in dictionary.keys():
        if key in manually_sorted_keys:
            continue
        value = dictionary[key]
        # isinstance is checked first so that only dicts are truth-tested;
        # other values may not support bool() (e.g. multi-element arrays).
        if isinstance(value, dict) and value:
            multi_line_keys.append(key)
        else:
            single_line_keys.append(key)

    for key in sorted(single_line_keys):
        yield key, dictionary[key]

    for key in sorted(multi_line_keys):
        yield key, PATHCHANGE
        for k, val in iterate_flattened_separately(dictionary[key],
                                                   manually_sorted_keys):
            yield join_paths(key, k), val
200
201
202
def iterate_flattened(d):
    """
    Walk a nested dictionary and yield ``(dotted_path, value)`` per leaf.

    Keys are visited in sorted order on every level. Values that are dicts
    are descended into; everything else counts as a leaf.
    """
    for name in sorted(d.keys()):
        entry = d[name]
        if not isinstance(entry, dict):
            yield name, entry
            continue
        # prefix the paths coming back from the sub-dictionary with our key
        for sub_path, leaf in iterate_flattened(entry):
            yield join_paths(name, sub_path), leaf
215
216
217
def set_by_dotted_path(d, path, value):
    """
    Store ``value`` in the nested dict ``d`` at the given dotted path.

    Intermediate dictionaries that do not exist yet are created.

    Examples:
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}
    """
    keys = path.split('.')
    leaf_key = keys.pop()  # the final component names the entry itself

    node = d
    for key in keys:
        if key not in node:
            node[key] = dict()
        node = node[key]
    node[leaf_key] = value
239
240
241
def get_by_dotted_path(d, path):
    """
    Look up an entry in nested dictionaries via a dotted path.

    An empty path returns ``d`` itself. If any component of the path is
    missing, ``None`` is returned.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d

    node = d
    for part in path.split('.'):
        if part in node:
            node = node[part]
        else:
            return None
    return node
258
259
260
def iter_path_splits(path):
    """
    Yield every way of cutting a dotted path into a (head, tail) pair.

    The head may be empty, the tail always holds at least the last
    component.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    parts = path.split('.')
    for cut, _ in enumerate(parts):
        head = join_paths(*parts[:cut])
        tail = join_paths(*parts[cut:])
        yield head, tail
277
278
279
def iter_prefixes(path):
    """
    Yield all non-empty prefixes of a dotted path, shortest first.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    parts = path.split('.')
    for last_index, _ in enumerate(parts):
        yield join_paths(*parts[:last_index + 1])
290
291
292
def join_paths(*parts):
    """
    Combine path fragments into one dotted path.

    Empty fragments are skipped and leading/trailing dots are stripped from
    each remaining fragment before joining them with '.'.
    """
    cleaned = [part.strip('.') for part in parts if part]
    return '.'.join(cleaned)
295
296
297
def is_prefix(pre_path, path):
    """
    Tell whether ``pre_path`` is a path-prefix of ``path``.

    Leading and trailing dots are ignored on both arguments. An empty
    ``pre_path`` is a prefix of everything; otherwise ``path`` has to start
    with ``pre_path`` followed by a dot.
    """
    prefix = pre_path.strip('.')
    candidate = path.strip('.')
    if not prefix:
        return True
    return candidate.startswith(prefix + '.')
302
303
304
def convert_to_nested_dict(dotted_dict):
    """
    Build a nested dict from a dict whose keys are dotted paths.

    Every leaf of ``dotted_dict`` (as produced by ``iterate_flattened``) is
    written into a fresh dict with ``set_by_dotted_path``.
    """
    result = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(result, dotted_key, leaf)
    return result
310
311
312
def print_filtered_stacktrace():
    """
    Print the exception currently being handled to stderr.

    If the innermost frame of the traceback belongs to a module that defines
    ``__sacred__``, the complete traceback is printed. Otherwise only the
    frames of modules without ``__sacred__`` are printed, followed by the
    exception itself.

    Must be called while an exception is being handled (it reads
    ``sys.exc_info()``).
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()

    # walk down to the frame in which the exception was actually raised
    innermost = exc_traceback
    while innermost.tb_next is not None:
        innermost = innermost.tb_next

    if '__sacred__' in innermost.tb_frame.f_globals:
        print("Exception originated from within Sacred.\n"
              "Traceback (most recent calls):", file=sys.stderr)
        tb.print_tb(exc_traceback)
        tb.print_exception(exc_type, exc_value, None)
        return

    print("Traceback (most recent calls WITHOUT Sacred internals):",
          file=sys.stderr)
    step = exc_traceback
    while step is not None:
        if '__sacred__' not in step.tb_frame.f_globals:
            tb.print_tb(step, 1)  # limit=1: print just this one frame
        step = step.tb_next
    summary = "\n".join(tb.format_exception_only(exc_type, exc_value))
    print(summary.strip(), file=sys.stderr)
333
334
335
def is_subdir(path, directory):
    """
    Tell whether ``path`` is ``directory`` itself or located inside it.

    Both arguments are resolved (symlinks, relative parts) and compared as
    absolute paths with a trailing separator, so '/ab' is not considered to
    be inside '/a'.
    """
    def canonical(location):
        return os.path.abspath(os.path.realpath(location)) + os.sep

    return canonical(path).startswith(canonical(directory))
340
341
342
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """
    Let the decorator ``wrapped`` be used with or without keyword arguments.

    With positional arguments the call is forwarded to ``wrapped`` directly
    (plain decorator usage). Without positional arguments a partial of
    ``wrapped`` with the given keyword arguments bound is returned, to be
    used as the actual decorator.
    """
    # here wrapped is itself a decorator
    if not args:  # used with kwargs only, so we need to return a decorator
        return partial(wrapped, **kwargs)
    # used as a normal decorator, so just call it
    return wrapped(*args, **kwargs)
350
351
352
def get_inheritors(cls):
    """
    Collect all direct and indirect subclasses of ``cls``.

    Returns a set; ``cls`` itself is not part of it.
    """
    found = set()
    pending = [cls]
    while pending:
        for sub in pending.pop().__subclasses__():
            if sub in found:
                continue  # already reached via another base class
            found.add(sub)
            pending.append(sub)
    return found
363
364
365
# Credit to Zarathustra and epost from stackoverflow
366
# Taken from http://stackoverflow.com/a/1176023/1388435
367
def convert_camel_case_to_snake_case(name):
    """Turn a CamelCase name into its snake_case form."""
    # 1st pass: underscore before a capitalized word (capital + lowercase run)
    # 2nd pass: underscore between a lowercase letter/digit and a capital
    for pattern in ('(.)([A-Z][a-z]+)', '([a-z0-9])([A-Z])'):
        name = re.sub(pattern, r'\1_\2', name)
    return name.lower()
371
372
373
def apply_backspaces_and_linefeeds(text):
    """
    Resolve backspace and carriage-return characters like a terminal would.

    The text is processed line by line: ``'\\b'`` moves the cursor one
    position to the left (never past the start of the line), ``'\\r'`` moves
    it to the start of the line, and every other character is written at the
    cursor position, overwriting what was there.
    """
    def render(line):
        buf, pos = [], 0
        for char in line:
            if char == '\r':
                pos = 0
            elif char == '\b':
                pos = max(0, pos - 1)
            else:
                # normal character: append at the end, overwrite otherwise
                if pos == len(buf):
                    buf.append(char)
                else:
                    buf[pos] = char
                pos += 1
        return ''.join(buf)

    return '\n'.join(render(line) for line in text.split('\n'))
397
398
399
# Code adapted from here:
400
# https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/
401
def natural_sort(l):
    """
    Return the strings in ``l`` as a new list in natural ("human") order.

    Runs of ASCII digits are compared by numeric value and everything else
    case-insensitively, so that e.g. 'a2' sorts before 'a10'.
    """
    def alphanum_key(key):
        # Splitting on a capturing group alternates text and digit runs:
        # odd indices are always the captured [0-9]+ chunks. Going by
        # position instead of str.isdigit() keeps non-ASCII digit characters
        # (e.g. a superscript two) in the text chunks, where int() is never
        # applied to them.
        return [int(chunk) if index % 2 else chunk.lower()
                for index, chunk in enumerate(re.split('([0-9]+)', key))]

    return sorted(l, key=alphanum_key)
407