Completed
Push — master ( e81713...095af2 )
by Ionel Cristian
46s
created

src.pytest_benchmark.SafeJSONEncoder.default()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 2
rs 10
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import json
6
import os
7
import platform
8
import re
9
import subprocess
10
import sys
11
import types
12
from datetime import datetime
13
from decimal import Decimal
14
from functools import partial
15
16
from .compat import PY3
17
18
try:
    from subprocess import check_output
except ImportError:
    # subprocess.check_output only exists from Python 2.7 on; provide an
    # equivalent for older interpreters.
    def check_output(*popenargs, **kwargs):
        """Run a command and return its captured stdout.

        Raises:
            ValueError: if the caller passes ``stdout`` (it is always a pipe here).
            subprocess.CalledProcessError: if the command exits with a non-zero status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured, _ = proc.communicate()
        exit_code = proc.poll()
        if not exit_code:
            return captured
        failed_cmd = kwargs.get("args")
        if failed_cmd is None:
            failed_cmd = popenargs[0]
        raise subprocess.CalledProcessError(exit_code, failed_cmd)
33
34
# Human-readable labels for the unit prefixes returned by time_unit().
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)",
}
40
41
42
class SecondsDecimal(Decimal):
    """Decimal number of seconds with a human-friendly ``str()``.

    ``str()`` scales the value with ``format_time`` and appends ``s``
    (so 0.0015 renders as ``1.50ms``). ``as_string`` and ``float()`` work
    from the plain, unscaled decimal text instead.
    """

    def __float__(self):
        # Parse the undecorated decimal text. str(self) is overridden below
        # and would include the scaled value plus unit suffix.
        plain_text = super(SecondsDecimal, self).__str__()
        return float(plain_text)

    def __str__(self):
        seconds = float(super(SecondsDecimal, self).__str__())
        return "{0}s".format(format_time(seconds))

    @property
    def as_string(self):
        """The value as plain decimal text, without scaling or unit suffix."""
        return super(SecondsDecimal, self).__str__()
52
53
54
class NameWrapper(object):
    """Wrap an object (typically a timer callable) so it prints as a dotted name.

    ``str()`` yields ``<module>.<name>`` when those attributes are available.
    The module prefix is omitted when ``__module__`` is missing or empty.
    ``repr(target)`` is used in place of the name when ``__name__`` is missing.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        # __module__ can exist but be None (e.g. on some built-in callables).
        # Only prefix it when it is actually set.
        module = getattr(self.target, '__module__', None)
        name = module + "." if module else ""
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
65
66
67
def get_tag():
    """Build a tag ``<commit id>_<timestamp>``.

    ``_uncommitted-changes`` is appended when the working tree is dirty.
    """
    commit_info = get_commit_info()
    timestamp = get_current_time()
    suffix = '_uncommitted-changes' if commit_info['dirty'] else ''
    return '%s_%s%s' % (commit_info['id'], timestamp, suffix)
70
71
72
def get_machine_id():
    """Identify the host as ``<system>-<implementation>-<major.minor>-<bits>``.

    For example: ``Linux-CPython-3.6-64bit``.
    """
    major_minor = ".".join(platform.python_version_tuple()[:2])
    fields = (
        platform.system(),
        platform.python_implementation(),
        major_minor,
        platform.architecture()[0],
    )
    return "-".join(fields)
79
80
81
def get_commit_info():
    """Describe the version-control state of the current working directory.

    Returns:
        A dict with these keys:

        * ``id``: the commit hash. It is ``'unversioned'`` when neither
          ``.git`` nor ``.hg`` exists, and ``'unknown'`` when querying the
          VCS failed.
        * ``dirty``: True when uncommitted changes were detected.
        * ``error``: present only on failure; the repr of the exception.

    Any exception raised while querying is swallowed on purpose: this is
    best-effort metadata.
    """
    dirty = False
    try:
        if os.path.exists('.git'):
            described = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                     universal_newlines=True).strip()
            pieces = described.split('-')
            # --dirty appends a trailing "-dirty" component.
            if pieces[-1].strip() == 'dirty':
                dirty = True
                pieces.pop()
            # The hash is the last component; strip the 'g' marker git puts in front of it.
            commit = pieces[-1].strip('g')
        elif os.path.exists('.hg'):
            ident = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            # Mercurial marks a modified working copy with a trailing '+'.
            if ident[-1] == '+':
                dirty = True
            commit = ident.strip('+')
        else:
            commit = 'unversioned'
    except Exception as exc:
        return {
            'id': 'unknown',
            'dirty': dirty,
            'error': repr(exc),
        }
    return {
        'id': commit,
        'dirty': dirty
    }
108
109
110
def get_current_time():
    """Return the local wall-clock time formatted as ``YYYYMMDD_HHMMSS``."""
    return "{0:%Y%m%d_%H%M%S}".format(datetime.now())
112
113
114
def first_or_value(obj, value):
    """Return the single element of ``obj``, or ``value`` when ``obj`` is falsy.

    Raises:
        ValueError: if ``obj`` is truthy but does not hold exactly one element.
    """
    if not obj:
        return value
    (only,) = obj
    return only
119
120
121
def short_filename(path, machine_id=None):
    """Shorten a path-like object (anything with ``.parts``) for display.

    The first component is dropped when it equals ``machine_id``. The last
    component is cut at its first ``_``. The remaining components are joined
    with ``/``.
    """
    pieces = list(path.parts)
    # Compare against the untouched first component, before any shortening.
    drop_first = bool(pieces) and pieces[0] == machine_id
    if pieces:
        pieces[-1] = pieces[-1].split('_')[0]
    if drop_first:
        del pieces[0]
    return '/'.join(pieces)
131
132
133
def load_timer(string):
    """Resolve a dotted ``module.attr`` spec to a timer wrapped in NameWrapper.

    The pseudo-module ``pep418`` maps to the stdlib ``time`` module on
    Python 3, and to this package's ``pep418`` module otherwise.

    Raises:
        argparse.ArgumentTypeError: if ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        # __import__ returns the top-level package; fetch the leaf module itself.
        return NameWrapper(getattr(sys.modules[module_name], attr))
    if PY3:
        import time as source
    else:
        from . import pep418 as source
    return NameWrapper(getattr(source, attr))
148
149
150
class RegressionCheck(object):
    """Base class for the regression checks built by ``parse_compare_fail``.

    Subclasses provide ``compute(current, compared)``, which returns a
    number that is tested against ``threshold``. ``field`` is the key read
    from both mappings.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message if the computed value exceeds the threshold, else None."""
        computed = self.compute(current, compared)
        if not computed > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, computed, self.threshold
        )
161
162
163
class PercentageRegressionCheck(RegressionCheck):
    """Regression measured as the percentage increase of ``field`` over the compared run."""

    def compute(self, current, compared):
        """Return by how many percent ``current[field]`` exceeds ``compared[field]``.

        A zero baseline makes the ratio undefined. An unchanged zero yields
        0 (no regression). Any nonzero current value yields +inf, so it
        always fails.
        """
        val = compared[self.field]
        if not val:
            # Previously 0 -> 0 returned inf and was reported as a regression.
            if not current[self.field]:
                return 0.0
            return float("inf")
        return current[self.field] / val * 100 - 100
169
170
171
class DifferenceRegressionCheck(RegressionCheck):
    """Regression measured as the absolute increase of ``field`` over the compared run."""

    def compute(self, current, compared):
        key = self.field
        return current[key] - compared[key]
174
175
176
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``FIELD:THRESHOLD`` spec into a regression check.

    ``FIELD`` is one of min/max/mean/median/stddev/iqr. ``THRESHOLD`` takes
    one of two forms:

    * a one- or two-digit percentage (``min:5%``), giving a
      PercentageRegressionCheck;
    * a plain or scientific number (``mean:0.001``, ``max:1e-3``), giving a
      DifferenceRegressionCheck.

    Raises:
        argparse.ArgumentTypeError: if ``string`` does not match.
    """
    # The pattern is a raw string: in a normal literal "\." is an invalid
    # escape sequence (SyntaxWarning on Python 3.12+).
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
188
189
190
def parse_warmup(string):
    """Parse a warmup switch (case-insensitive, surrounding whitespace ignored).

    ``auto`` means "on for PyPy only". ``off``/``false``/``no`` mean False.
    ``on``/``true``/``yes`` and the empty string mean True.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ("off", "false", "no"):
        return False
    if normalized in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
200
201
202
def parse_timer(string):
    """Validate a dotted timer spec via load_timer and return its display name."""
    timer = load_timer(string)
    return str(timer)
204
205
206
def parse_sort(string):
    """Normalize and validate the sort key (case-insensitive, whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the key is not one of the accepted values.
    """
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
214
215
216
def parse_columns(string):
    """Parse a comma-separated list of column names.

    Names are lower-cased and stripped of surrounding whitespace. Their
    order is preserved.

    Raises:
        argparse.ArgumentTypeError: if any name is not an allowed column.
    """
    allowed_columns = ["min", "max", "mean", "stddev", "median", "iqr",
                       "outliers", "rounds", "iterations"]
    # s.strip() rather than str.strip(s): the latter rejects unicode input on Python 2.
    columns = [s.strip() for s in string.lower().split(',')]
    invalid = set(columns) - set(allowed_columns)
    if invalid:
        # Sort so the message is deterministic (set iteration order varies between runs).
        msg = "Invalid column name(s): %s. " % ', '.join(sorted(invalid))
        msg += "The only valid column names are: %s" % ', '.join(allowed_columns)
        raise argparse.ArgumentTypeError(msg)
    return columns
227
228
229
def parse_rounds(string):
    """Parse a positive integer round count.

    Raises:
        argparse.ArgumentTypeError: if ``string`` is not an integer or is below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
238
239
240
def parse_seconds(string):
    """Validate a decimal number of seconds and return it as plain decimal text.

    Raises:
        argparse.ArgumentTypeError: if the value cannot be parsed.
    """
    try:
        parsed = SecondsDecimal(string)
        return parsed.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
245
246
247
def parse_save(string):
    """Validate a save name: it must be non-empty and free of path-unsafe characters.

    Raises:
        argparse.ArgumentTypeError: if empty or containing any of ``\\ / : * ? < > |``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    forbidden = r"\/:*?<>|"
    offending = ''.join(filter(lambda ch: ch in string, forbidden))
    if not offending:
        return string
    raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % offending)
254
255
256
def time_unit(value):
    """Pick a display unit for a duration in seconds.

    Returns:
        ``(prefix, multiplier)``: ``("n", 1e9)`` below 1µs, ``("u", 1e6)``
        below 1ms, ``("m", 1e3)`` below 1s, and ``("", 1.)`` otherwise.
    """
    for limit, prefix, multiplier in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < limit:
            return prefix, multiplier
    return "", 1.
265
266
267
def format_time(value):
    """Render a duration in seconds with two decimals plus a unit prefix (e.g. ``1.50m``)."""
    prefix, multiplier = time_unit(value)
    return format(value * multiplier, ".2f") + prefix
270
271
272
class cached_property(object):
    """Descriptor that computes a value once per instance.

    On first attribute access the wrapped function runs. Its result is
    stored in the instance ``__dict__`` under the function's name, so later
    lookups find it there directly. Accessing the attribute on the class
    returns the descriptor itself.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        if obj is not None:
            result = self.func(obj)
            obj.__dict__[self.func.__name__] = result
            return result
        return self
282
283
284
def funcname(f):
    """Return a display name for a callable.

    For ``functools.partial`` objects, the wrapped function's ``__name__``
    is used. Falls back to ``str(f)`` when no ``__name__`` is available.
    """
    try:
        target = f.func if isinstance(f, partial) else f
        return target.__name__
    except AttributeError:
        return str(f)
292
293
294
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the function attributes used here (``__code__``,
    ``__globals__``, ...), such as builtins or ``functools.partial``
    objects, are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, "replace"):
            # Python 3.8+: the positional CodeType() signature changed (3.8 added
            # co_posonlyargcount; 3.11 added co_qualname and co_exceptiontable).
            # The hand-built calls below therefore raise TypeError there.
            # replace() returns a copy of the code object portably.
            co2 = co.replace()
        elif PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() takes no keyword-only defaults. Copy them so the clone
        # accepts the same calls as the original.
        kwdefaults = getattr(f, "__kwdefaults__", None)
        if kwdefaults:
            clone.__kwdefaults__ = dict(kwdefaults)
        return clone
    except AttributeError:
        return f
326
327
328
def format_dict(obj):
    """Render a mapping as ``{key: <json value>, ...}`` with keys sorted and left unquoted."""
    rendered = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{" + ", ".join(rendered) + "}"
330
331
332
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that never fails on unsupported objects.

    Objects the stock encoder cannot handle are emitted as the string
    ``UNSERIALIZABLE[<repr>]`` instead of raising.
    """

    def default(self, o):
        template = "UNSERIALIZABLE[%r]"
        return template % o
335
336
337
def safe_dumps(obj, **kwargs):
    """``json.dumps`` using SafeJSONEncoder; extra keyword arguments are forwarded."""
    encoder_cls = SafeJSONEncoder
    return json.dumps(obj, cls=encoder_cls, **kwargs)
339
340
341
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` in a generator that reports progress as it goes.

    For each item, ``format_string`` is formatted with ``pos`` (1-based),
    ``total``, ``value`` and any extra ``kwargs``. The line is shown via
    ``terminal_reporter.rewrite(..., black=True, bold=True)``, and
    ``(line, item)`` is yielded. ``len(iterable)`` is taken eagerly at call
    time, so ``iterable`` must be sized.
    """
    total = len(iterable)

    def _generate():
        for index, element in enumerate(iterable, 1):
            line = format_string.format(pos=index, total=total, value=element, **kwargs)
            terminal_reporter.rewrite(line, black=True, bold=True)
            yield line, element

    return _generate()
350
351
352
def report_noprogress(iterable, *args, **kwargs):
    """Silent counterpart of report_progress: yield ``("", item)`` per item, ignoring extra arguments."""
    for element in iterable:
        yield "", element
355
356
357
def slugify(name):
    """Make ``name`` safe to use as a file name.

    Each of ``\\ / : * ? < > |`` and space is replaced by ``_``. After each
    replacement, a single ``'__' -> '_'`` pass runs. This shrinks runs of
    underscores but does not fully collapse long ones: three spaces in a
    row leave ``__``.
    """
    # Raw string: in a normal literal "\/" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The characters are unchanged.
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
361
362
363
def annotate_source(bench, source):
    """Record the original names on ``bench`` and tag them with ``source``.

    ``canonical_name``/``canonical_fullname`` always receive the original
    ``name``/``fullname``. When ``source`` is truthy, ``name`` and
    ``fullname`` get a `` (<source>)`` suffix and ``source`` is stored. The
    dict is modified in place and also returned.
    """
    # Read both keys before mutating, so a missing key leaves bench untouched.
    original_name, original_fullname = bench["name"], bench["fullname"]
    bench["canonical_name"] = original_name
    bench["canonical_fullname"] = original_fullname
    if source:
        bench["name"] = "{0} ({1})".format(original_name, source)
        bench["fullname"] = "{0} ({1})".format(original_fullname, source)
        bench["source"] = source
    return bench
375