Completed
Push — master ( 3e335d...193896 )
by Klaus
01:10
created

apply_backspaces_and_linefeeds()   F

Complexity

Conditions 11

Size

Total Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 1
Metric Value
cc 11
c 2
b 0
f 1
dl 0
loc 34
rs 3.1764

How to fix   Complexity   

Complexity

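Complex code units like apply_backspaces_and_linefeeds() (here a function, not a class) often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. For a class, a common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes; for a function, look for a sub-task with one clear purpose — here, rendering a single line by applying its carriage returns and backspaces.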

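Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster. For a function, the corresponding refactoring is Extract Function (Extract Method): move the cohesive sub-task into its own helper and call it from the original function.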

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import collections
6
import logging
7
import os.path
8
import re
9
import sys
10
import traceback as tb
11
from functools import partial
12
13
import wrapt
14
15
16
# Module-level marker: print_filtered_stacktrace() looks for this name in a
# frame's globals to recognise (and filter out) frames from Sacred files.
__sacred__ = True  # marks files that should be filtered from stack traces


# Public API of this module, i.e. the names exported by ``from ... import *``.
__all__ = ["NO_LOGGER", "PYTHON_IDENTIFIER", "CircularDependencyError",
           "ObserverError", "SacredInterrupt", "TimeoutInterrupt",
           "create_basic_stream_logger", "recursive_update",
           "iterate_flattened", "iterate_flattened_separately",
           "set_by_dotted_path", "get_by_dotted_path", "iter_path_splits",
           "iter_prefixes", "join_paths", "is_prefix",
           "convert_to_nested_dict", "convert_camel_case_to_snake_case",
           "print_filtered_stacktrace", "is_subdir",
           "optional_kwargs_decorator", "get_inheritors",
           "apply_backspaces_and_linefeeds", "StringIO", "FileNotFoundError"]
28
29
# A PY2 compatible FileNotFoundError
# Python 2 has no built-in FileNotFoundError, so an IOError subclass carrying
# errno.ENOENT is defined there. The matching StringIO implementation is also
# chosen per major version, so both names can be imported from this module.
if sys.version_info[0] == 2:
    import errno

    class FileNotFoundError(IOError):
        """Python 2 stand-in for the built-in Python 3 FileNotFoundError."""

        def __init__(self, msg):
            # IOError(errno, strerror): errno is fixed to ENOENT, the message
            # becomes the error text.
            super(FileNotFoundError, self).__init__(errno.ENOENT, msg)
    from StringIO import StringIO
else:
    # Reassign so that we can import it from here
    # (binds the built-in FileNotFoundError as a module attribute, so
    # ``from <this module> import FileNotFoundError`` also works on Python 3).
    FileNotFoundError = FileNotFoundError
    from io import StringIO
41
42
# Logger named 'ignore' with ``disabled`` set, so records sent to it are
# dropped. NOTE(review): presumably used as a no-op logger by callers that
# require a logger object — confirm against call sites.
NO_LOGGER = logging.getLogger('ignore')
NO_LOGGER.disabled = 1

# Unique sentinel that iterate_flattened_separately() yields in place of a
# value right before it descends into a nested dict (a new path prefix).
PATHCHANGE = object()

# A valid ASCII Python identifier: a letter or underscore followed by
# letters, digits or underscores. The end is anchored with ``\Z`` rather than
# ``$`` because ``$`` also matches just before a trailing newline, which would
# wrongly accept strings such as "foo\n".
PYTHON_IDENTIFIER = re.compile(r"^[a-zA-Z_][_a-zA-Z0-9]*\Z")
48
49
50
class CircularDependencyError(Exception):
    """Signal that the ingredients of the current experiment form a cycle.

    The experiment's ingredients depend on each other circularly.
    """
52
53
54
class ObserverError(Exception):
    """Signal an observer failure that must not make the run fail.

    Observers raise this error for problems that are not supposed to abort
    the run they observe.
    """
56
57
58
class SacredInterrupt(Exception):
    """Common base class of all custom interrupts.

    The class attribute ``STATUS`` holds the run status associated with this
    kind of interrupt; subclasses may override it.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "INTERRUPTED"
65
66
67
class TimeoutInterrupt(SacredInterrupt):
    """Signal that the experiment timed out.

    Client code can raise this exception to indicate that the run exceeded
    its time limit and was interrupted because of that. The status of the
    interrupted run is then set to ``TIMEOUT``.

    For more information see :ref:`custom_interrupts`.
    """

    STATUS = "TIMEOUT"
78
79
80
def create_basic_stream_logger():
    """Configure the root logger for simple console output and return it.

    The root logger's level is set to INFO, all handlers previously attached
    to it are discarded (without being closed), and a single StreamHandler
    with the format ``LEVEL - logger name - message`` is installed.
    """
    root_logger = logging.getLogger('')
    root_logger.setLevel(logging.INFO)
    root_logger.handlers = []
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(
        logging.Formatter('%(levelname)s - %(name)s - %(message)s'))
    root_logger.addHandler(console_handler)
    return root_logger
89
90
91
def recursive_update(d, u):
    """
    Given two dictionaries d and u, update dict d recursively.

    Values in ``u`` that are mappings are merged into the corresponding entry
    of ``d`` (starting from a new empty dict if ``d`` lacks that key); all
    other values overwrite the entry in ``d``. ``d`` is modified in place and
    also returned.

    E.g.:
    d = {'a': {'b' : 1}}
    u = {'c': 2, 'a': {'d': 3}}
    => {'a': {'b': 1, 'd': 3}, 'c': 2}
    """
    # ``collections.Mapping`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc`` since Python 3.3. Python 2 only has the old name.
    try:
        from collections.abc import Mapping
    except ImportError:  # Python 2
        from collections import Mapping

    for k, v in u.items():
        if isinstance(v, Mapping):
            d[k] = recursive_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d
107
108
109
def iterate_flattened_separately(dictionary, manually_sorted_keys=None):
    """
    Recursively iterate over the items of a dictionary in a special order.

    Yields ``(key, value)`` pairs in three phases:

    1. keys from ``manually_sorted_keys`` (in that order) that are present in
       ``dictionary``, with their values as-is (not flattened);
    2. the remaining keys whose value is not a non-empty dict, sorted;
    3. the remaining keys whose value is a non-empty dict, sorted. For each
       of these, ``(key, PATHCHANGE)`` is yielded first, followed by the
       recursively flattened contents keyed by full dotted paths.

    ``manually_sorted_keys`` is applied again at every nesting level.
    """
    priority = [] if manually_sorted_keys is None else manually_sorted_keys

    def _has_children(name):
        # A non-empty dict opens a new sub-path; everything else is a leaf.
        entry = dictionary[name]
        return bool(entry) and isinstance(entry, dict)

    for name in priority:
        if name in dictionary:
            yield name, dictionary[name]

    leaves = [name for name in dictionary.keys()
              if name not in priority and not _has_children(name)]
    for name in sorted(leaves):
        yield name, dictionary[name]

    branches = [name for name in dictionary.keys()
                if name not in priority and _has_children(name)]
    for name in sorted(branches):
        yield name, PATHCHANGE
        nested = iterate_flattened_separately(dictionary[name], priority)
        for sub_path, sub_value in nested:
            yield join_paths(name, sub_path), sub_value
139
140
141
def iterate_flattened(d):
    """
    Recursively iterate over the items of a dictionary.

    Yields ``(path, value)`` for every leaf, visiting keys in sorted order at
    each level. Top-level keys are yielded as they are; deeper leaves get a
    full dotted path. Dict values are descended into rather than yielded, so
    an empty nested dict contributes no items.
    """
    for key in sorted(d.keys()):
        value = d[key]
        if not isinstance(value, dict):
            yield key, value
            continue
        for sub_key, sub_value in iterate_flattened(d[key]):
            yield join_paths(key, sub_key), sub_value
154
155
156
def set_by_dotted_path(d, path, value):
    """
    Set an entry in a nested dict using a dotted path.

    Missing intermediate dictionaries are created; existing intermediate
    entries are reused as they are. ``d`` is modified in place.

    Examples:
    >>> d = {'foo': {'bar': 7}}
    >>> set_by_dotted_path(d, 'foo.bar', 10)
    >>> d
    {'foo': {'bar': 10}}
    >>> set_by_dotted_path(d, 'foo.d.baz', 3)
    >>> d
    {'foo': {'bar': 10, 'd': {'baz': 3}}}
    """
    head, separator, leaf = path.rpartition('.')
    # Without a '.', the whole path is the leaf key and there are no parents.
    parents = head.split('.') if separator else []
    target = d
    for segment in parents:
        if segment not in target:
            target[segment] = dict()
        target = target[segment]
    target[leaf] = value
178
179
180
def get_by_dotted_path(d, path, default=None):
    """
    Get an entry from nested dictionaries using a dotted path.

    A falsy ``path`` (e.g. ``''`` or ``None``) returns ``d`` itself. If any
    component of the path is missing, ``default`` is returned.

    Example:
    >>> get_by_dotted_path({'foo': {'a': 12}}, 'foo.a')
    12
    """
    if not path:
        return d
    node = d
    for segment in path.split('.'):
        if segment in node:
            node = node[segment]
        else:
            return default
    return node
197
198
199
def iter_path_splits(path):
    """
    Iterate over possible splits of a dotted path.

    Each split is a ``(head, tail)`` pair of dotted paths. The head may be
    empty; the tail always contains at least the last path component.

    Example:
    >>> list(iter_path_splits('foo.bar.baz'))
    [('', 'foo.bar.baz'), ('foo', 'bar.baz'), ('foo.bar', 'baz')]
    """
    components = path.split('.')
    for cut in range(len(components)):
        yield join_paths(*components[:cut]), join_paths(*components[cut:])
216
217
218
def iter_prefixes(path):
    """
    Iterate through all (non-empty) prefixes of a dotted path.

    Prefixes are yielded from shortest to longest; the last one is the whole
    path.

    Example:
    >>> list(iter_prefixes('foo.bar.baz'))
    ['foo', 'foo.bar', 'foo.bar.baz']
    """
    consumed = []
    for component in path.split('.'):
        consumed.append(component)
        yield join_paths(*consumed)
229
230
231
def join_paths(*parts):
    """Join different parts together to a valid dotted path.

    Falsy parts (e.g. ``''`` or ``None``) are skipped. Every remaining part is
    converted with ``str()`` and stripped of leading/trailing dots before
    the parts are joined with ``'.'``.
    """
    pieces = []
    for part in parts:
        if not part:
            continue
        pieces.append(str(part).strip('.'))
    return '.'.join(pieces)
234
235
236
def is_prefix(pre_path, path):
    """Return True if pre_path is a path-prefix of path.

    Leading and trailing dots are ignored on both arguments. An empty
    ``pre_path`` is a prefix of every path. Otherwise ``path`` must continue
    with a ``'.'`` after ``pre_path``, so a path is not a prefix of itself and
    ``'foo'`` is not a prefix of ``'foobar'``.
    """
    prefix, target = pre_path.strip('.'), path.strip('.')
    if not prefix:
        return True
    return target.startswith(prefix + '.')
241
242
243
def convert_to_nested_dict(dotted_dict):
    """Convert a dict with dotted path keys to corresponding nested dict.

    A key like ``'a.b'`` becomes the nested entry ``{'a': {'b': ...}}``. The
    input is walked with iterate_flattened(), so dict values are merged in
    leaf by leaf and empty nested dicts produce no entries. A new dict is
    returned; leaf values are not copied.
    """
    nested = {}
    for dotted_key, leaf in iterate_flattened(dotted_dict):
        set_by_dotted_path(nested, dotted_key, leaf)
    return nested
249
250
251
def print_filtered_stacktrace():
    """Print the exception currently being handled to stderr.

    Must be called while an exception is being handled (it reads
    ``sys.exc_info()`` and walks the traceback). Frames whose module globals
    contain ``__sacred__`` count as Sacred internals:

    * if the innermost frame belongs to Sacred, the full traceback is
      printed, preceded by a note that the error originated within Sacred;
    * otherwise only the non-Sacred frames are printed, followed by the
      exception type and message.
    """
    exc_type, exc_value, exc_traceback = sys.exc_info()

    def _is_sacred(tb_entry):
        return '__sacred__' in tb_entry.tb_frame.f_globals

    # Walk to the innermost traceback entry, where the exception was raised.
    innermost = exc_traceback
    while innermost.tb_next is not None:
        innermost = innermost.tb_next

    if not _is_sacred(innermost):
        print("Traceback (most recent calls WITHOUT Sacred internals):",
              file=sys.stderr)
        entry = exc_traceback
        while entry is not None:
            if not _is_sacred(entry):
                tb.print_tb(entry, 1)
            entry = entry.tb_next
        summary = tb.format_exception_only(exc_type, exc_value)
        print("\n".join(summary).strip(), file=sys.stderr)
        return

    print("Exception originated from within Sacred.\n"
          "Traceback (most recent calls):", file=sys.stderr)
    tb.print_tb(exc_traceback)
    tb.print_exception(exc_type, exc_value, None)
272
273
274
def is_subdir(path, directory):
    """Return True if ``path`` is ``directory`` itself or lies inside it.

    Both paths are resolved with ``realpath`` (following symlinks) and made
    absolute. A trailing separator is appended to each, so ``/a/bc`` is not
    considered to be inside ``/a/b``.
    """
    def _normalized(raw):
        return os.path.abspath(os.path.realpath(raw)) + os.sep

    return _normalized(path).startswith(_normalized(directory))
279
280
281
# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
    """Let a decorator be used either bare or with keyword arguments.

    This is applied (via ``wrapt.decorator``) to another decorator
    ``wrapped``:

    * ``@wrapped`` (bare): the decorated object arrives as a positional
      argument, so ``wrapped`` is simply called with it.
    * ``@wrapped(key=value, ...)``: there are no positional arguments, so a
      ``functools.partial`` of ``wrapped`` bound to those keyword arguments
      is returned, and that partial then decorates the target.

    ``instance`` is part of the wrapper signature wrapt uses and is not read
    here.
    """
    # here wrapped is itself a decorator
    if args:  # means it was used as a normal decorator (so just call it)
        return wrapped(*args, **kwargs)
    else:  # used with kwargs, so we need to return a decorator
        return partial(wrapped, **kwargs)
289
290
291
def get_inheritors(cls):
    """Get a set of all classes that inherit from the given class.

    Both direct and indirect subclasses are collected; ``cls`` itself is not
    included. Each class appears once, even in diamond-shaped hierarchies.
    """
    found = set()
    frontier = [cls]
    while frontier:
        newcomers = [sub for sub in frontier.pop().__subclasses__()
                     if sub not in found]
        found.update(newcomers)
        frontier.extend(newcomers)
    return found
302
303
304
# Credit to Zarathustra and epost from stackoverflow
# Taken from http://stackoverflow.com/a/1176023/1388435
def convert_camel_case_to_snake_case(name):
    """Convert CamelCase to snake_case.

    Two passes insert underscores: first before a capitalized word
    (``HTTPResponse`` -> ``HTTP_Response``), then between a lowercase letter
    or digit and a following uppercase letter. The result is lowercased.
    """
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    fully_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return fully_split.lower()
310
311
312
def _render_terminal_line(line, keep_trailing_cr=False):
    r"""Render one line (containing no ``\n``) the way a terminal would.

    ``\r`` moves the cursor back to the start of the line and ``\b`` moves it
    one position left (never before the start); neither erases anything.
    Every other character is written at the cursor position, overwriting
    what is there or extending the line, and advances the cursor.

    If ``keep_trailing_cr`` is true and the line ends with ``\r``, that final
    ``\r`` is not interpreted but appended to the result unchanged.
    """
    suffix = ''
    if keep_trailing_cr and line.endswith('\r'):
        line, suffix = line[:-1], '\r'

    chars, cursor = [], 0
    for char in line:
        if char == '\r':
            cursor = 0
        elif char == '\b':
            cursor = max(0, cursor - 1)
        else:
            if cursor == len(chars):
                chars.append(char)
            else:
                chars[cursor] = char
            cursor += 1
    return ''.join(chars) + suffix


def apply_backspaces_and_linefeeds(text):
    r"""
    Interpret backspaces and carriage returns in text like a terminal would.

    The text is split at ``\n`` and each line is rendered independently:
    ``\r`` returns the cursor to the start of the line and ``\b`` moves it one
    position left, after which further characters overwrite existing ones.
    The control characters themselves are removed, so ``\r\n`` line endings
    collapse to ``\n``.

    If the final line ends with a carriage return, that ``\r`` is kept so the
    result can be concatenated with the next output chunk.

    Example:
    >>> apply_backspaces_and_linefeeds('10%\r20%\r')
    '20%\r'
    """
    lines = text.split('\n')
    last_index = len(lines) - 1
    # Only the very last line may keep a trailing '\r' (chunk continuation).
    rendered = [
        _render_terminal_line(line, keep_trailing_cr=(index == last_index))
        for index, line in enumerate(lines)]
    return '\n'.join(rendered)
346