Completed
Pull Request — master (#59)
by
unknown
01:30
created

get_cprofile_functions()   B

Complexity

Conditions 6

Size

Total Lines 34

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 34
rs 7.5384
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import ntpath
8
import os
9
import platform
10
import re
11
import subprocess
12
import sys
13
import types
14
import operator
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
19
from .compat import PY3
20
21
try:
    from subprocess import check_output
except ImportError:
    # Fallback for interpreters whose subprocess module lacks check_output
    # (it was added in Python 2.7).
    def check_output(*popenargs, **kwargs):
        """Run a command and return its stdout; raise CalledProcessError on non-zero exit."""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        out = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return out
        command = kwargs.get("args")
        raise subprocess.CalledProcessError(status, popenargs[0] if command is None else command)
36
37
# Human-readable labels for the unit prefixes produced by time_unit():
# "" = seconds, "m" = milli, "u" = micro, "n" = nano.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by parse_columns().
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "rounds", "iterations"]
44
45
46
class SecondsDecimal(Decimal):
    """Decimal number of seconds whose ``str()`` is a unit-scaled duration.

    ``str()`` goes through ``format_time`` and appends ``"s"`` (for example
    ``0.0015`` renders as ``"1.50ms"``); ``as_string`` returns the plain
    decimal text and ``float()`` converts that text to a float.
    """

    @property
    def as_string(self):
        """Plain decimal representation, without any time formatting."""
        return super(SecondsDecimal, self).__str__()

    def __float__(self):
        plain = super(SecondsDecimal, self).__str__()
        return float(plain)

    def __str__(self):
        plain = super(SecondsDecimal, self).__str__()
        return "{0}s".format(format_time(float(plain)))
56
57
58
class NameWrapper(object):
    """Wrap an object (e.g. a timer function) so ``str()`` yields a dotted name.

    ``str()`` returns ``"<module>.<name>"``. The module part is omitted when the
    target has no usable ``__module__``. The name part falls back to
    ``repr(target)`` when the target has no ``__name__``.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        module = getattr(self.target, '__module__', None)
        # Some callables (bound builtin methods, functions with __module__
        # reset) expose __module__ = None; concatenating it would raise
        # TypeError, so only prefix an actual module name.
        name = (module + ".") if module is not None else ""
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
69
70
71
def get_tag():
    """Build a tag ``<commit-id>_<timestamp>`` with ``_uncommitted-changes`` appended when dirty."""
    commit_info = get_commit_info()
    suffix = '_uncommitted-changes' if commit_info['dirty'] else ''
    return '%s_%s%s' % (commit_info['id'], get_current_time(), suffix)
74
75
76
def get_machine_id():
    """Identify the running interpreter as ``<system>-<implementation>-<major.minor>-<bits>``."""
    major_minor = ".".join(platform.python_version_tuple()[:2])
    fields = (
        platform.system(),
        platform.python_implementation(),
        major_minor,
        platform.architecture()[0],
    )
    return "-".join(fields)
83
84
85
def get_commit_info():
    """Describe the VCS revision of the current working directory.

    Returns a dict with:

    * ``id``: the commit id parsed from ``git describe`` or ``hg id``.
      It is ``'unversioned'`` when neither ``.git`` nor ``.hg`` exists.
    * ``dirty``: True when the VCS reports uncommitted changes.

    Any exception raised while querying the VCS is not propagated. Instead
    the result is ``id='unknown'``, ``dirty`` as determined so far, and an
    ``error`` key holding ``repr(exc)``.
    """
    dirty = False
    commit = 'unversioned'
    try:
        if os.path.exists('.git'):
            output = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                  universal_newlines=True).strip()
            pieces = output.split('-')
            dirty = pieces[-1].strip() == 'dirty'
            if dirty:
                pieces.pop()
            # the last piece is the hash, prefixed with "g" by git describe
            commit = pieces[-1].strip('g')
        elif os.path.exists('.hg'):
            output = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            # hg marks a modified working copy with a trailing "+"
            dirty = output[-1] == '+'
            commit = output.strip('+')
        return {'id': commit, 'dirty': dirty}
    except Exception as exc:
        return {'id': 'unknown', 'dirty': dirty, 'error': repr(exc)}
112
113
114
def get_current_time():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    now = datetime.now()
    return now.strftime("%Y%m%d_%H%M%S")
116
117
118
def first_or_value(obj, value):
    """Return the single element of ``obj``, or ``value`` when ``obj`` is falsy.

    Raises ValueError when ``obj`` holds more than one element.
    """
    if not obj:
        return value
    (only,) = obj
    return only
123
124
125
def short_filename(path, machine_id=None):
    """Build a compact '/'-joined name from a path object exposing ``.parts``.

    A leading component equal to ``machine_id`` is dropped, and the final
    component loses its last ``.suffix``.
    """
    segments = path.parts
    final_index = len(segments) - 1
    kept = []
    for index, segment in enumerate(segments):
        if index == 0 and segment == machine_id:
            continue
        if index == final_index:
            segment = segment.rsplit('.', 1)[0]
        kept.append(segment)
    return '/'.join(kept)
137
138
139
def load_timer(string):
    """Resolve a dotted ``module.attr`` spec to a timer wrapped in NameWrapper.

    The pseudo-module ``pep418`` resolves to :mod:`time` on Python 3 and to
    this package's ``pep418`` module otherwise; any other module is imported
    by name.

    Raises argparse.ArgumentTypeError when ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr_name = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        module = sys.modules[module_name]
    elif PY3:
        import time as module
    else:
        from . import pep418 as module
    return NameWrapper(getattr(module, attr_name))
154
155
156
class RegressionCheck(object):
    """Threshold check on a single field of benchmark stats.

    This base class does not define ``compute(current, compared)``;
    subclasses provide it, returning the number compared against
    ``threshold``.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else None."""
        value = self.compute(current, compared)
        if not value > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, value, self.threshold
        )
167
168
169
class PercentageRegressionCheck(RegressionCheck):
    """Regression expressed as the percent change of ``field`` relative to the compared run."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # a zero (falsy) baseline is treated as an infinite regression
        return float("inf")
175
176
177
class DifferenceRegressionCheck(RegressionCheck):
    """Regression expressed as the absolute difference ``current - compared`` of ``field``."""

    def compute(self, current, compared):
        field = self.field
        return current[field] - compared[field]
180
181
182
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``field:threshold`` spec into a regression check.

    ``field`` is one of min/max/mean/median/stddev/iqr. The threshold is one of:

    * a one- or two-digit percentage followed by ``%``, giving a
      PercentageRegressionCheck;
    * a non-negative number, optionally with an exponent, giving a
      DifferenceRegressionCheck.

    Raises argparse.ArgumentTypeError when the value does not match.
    """
    # Raw strings: "\." in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12); the compiled pattern is unchanged.
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
194
195
196
def parse_warmup(string):
    """Parse a warmup option into a bool.

    ``auto`` means "only on PyPy"; off/false/no disable it; on/true/yes or
    an empty value enable it. The input is case- and whitespace-insensitive.
    Raises argparse.ArgumentTypeError for anything else.
    """
    choice = string.lower().strip()
    if choice in ["off", "false", "no"]:
        return False
    if choice in ["on", "true", "yes", ""]:
        return True
    if choice == "auto":
        return platform.python_implementation() == "PyPy"
    raise argparse.ArgumentTypeError("Could not parse value: %r." % choice)
206
207
208
def name_formatter_short(bench):
    """Short display name: a ``test_`` prefix is dropped, and the first 4 chars of the source file name are appended."""
    label = bench["name"]
    source = bench["source"]
    if source:
        # presumably the run-number prefix of a saved-run file name
        label = "%s (%.4s)" % (label, os.path.basename(source))
    return label[5:] if label.startswith("test_") else label
215
216
217
def name_formatter_normal(bench):
    """Display name with the source path appended, its last component truncated to 12 chars."""
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    directory, slash, filename = source.rpartition('/')
    return "%s (%s%s%s)" % (label, directory, slash, filename[:12])
224
225
226
def name_formatter_long(bench):
    """Full display name: ``fullname``, followed by ``(source)`` when a source is set."""
    source = bench["source"]
    if not source:
        return bench["fullname"]
    return "%s (%s)" % (bench["fullname"], source)
231
232
233
# Benchmark-name formatters, keyed by the value accepted by parse_name_format().
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
)
238
239
240
def parse_name_format(string):
    """Normalise a name-format option and ensure it is a key of NAME_FORMATTERS."""
    key = string.lower().strip()
    if key not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % key)
    return key
246
247
248
def parse_timer(string):
    """Validate a ``module.attr`` timer spec and return the loaded timer's display name."""
    timer = load_timer(string)
    return str(timer)
250
251
252
def parse_sort(string):
    """Normalise a sort key option; only min/max/mean/stddev/name/fullname are accepted."""
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
260
261
262
def parse_columns(string):
    """Parse a comma-separated list of column names.

    Names are lower-cased and stripped; order and duplicates are preserved.
    Raises argparse.ArgumentTypeError naming every invalid column, together
    with the list of allowed names.
    """
    # s.strip() (not str.strip(s)) also works for Py2 unicode input.
    columns = [s.strip() for s in string.lower().split(',')]
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if invalid:
        # there are extra items in columns!
        # Sort so the message does not depend on set/hash iteration order.
        msg = "Invalid column name(s): %s. " % ', '.join(sorted(invalid))
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
271
272
273
def parse_rounds(string):
    """Parse a positive integer round count; raise ArgumentTypeError if invalid or < 1."""
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
282
283
284
def parse_seconds(string):
    """Validate a decimal number of seconds and return its plain decimal text.

    Any error from the conversion is reported as argparse.ArgumentTypeError.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
289
290
291
def parse_save(string):
    """Validate a save name: it must be non-empty and free of the characters \\ / : * ? < > |."""
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    forbidden = [c for c in r"\/:*?<>|" if c in string]
    if forbidden:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(forbidden))
    return string
298
299
300
def time_unit(value):
    """Pick a display unit for a duration in seconds.

    Returns ``(prefix, multiplier)``: ``("n", 1e9)`` below 1µs, ``("u", 1e6)``
    below 1ms, ``("m", 1e3)`` below 1s, and ``("", 1.)`` otherwise.
    """
    for limit, prefix, multiplier in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < limit:
            return prefix, multiplier
    return "", 1.
309
310
311
def format_time(value):
    """Format a duration in seconds with two decimals and a unit prefix (e.g. 0.0015 -> "1.50m")."""
    unit, scale = time_unit(value)
    scaled = value * scale
    return "{0:.2f}{1:s}".format(scaled, unit)
314
315
316
class cached_property(object):
    """Descriptor that computes a value once per instance and caches it.

    On first access the wrapped function runs and its result is stored in the
    instance ``__dict__`` under the function's name. The class defines no
    ``__set__``, so later attribute lookups find that instance entry first.
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is not None:
            computed = self.func(obj)
            obj.__dict__[self.func.__name__] = computed
            return computed
        return self
326
327
328
def funcname(f):
    """Best-effort name of a callable: ``functools.partial`` is unwrapped, and ``str(f)`` is the fallback."""
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
336
337
338
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default

    Objects lacking the function attributes used here (``__code__``,
    ``__globals__``, ...), such as builtins, are returned unchanged.
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # The CodeType constructor signature changed in Python 3.8
            # (posonlyargcount) and again in 3.11, so the explicit calls below
            # raise TypeError there. code.replace() copies every field portably.
            co2 = co.replace()
        elif PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() has no parameter for keyword-only defaults; without
        # copying them the clone would reject calls relying on those defaults.
        kwdefaults = getattr(f, '__kwdefaults__', None)
        if kwdefaults:
            clone.__kwdefaults__ = dict(kwdefaults)
        return clone
    except AttributeError:
        return f
370
371
372
def format_dict(obj):
    """Render a mapping as ``{key: <json value>, ...}``, with items in sorted order."""
    items = sorted(obj.items())
    body = ", ".join("%s: %s" % (key, json.dumps(val)) for key, val in items)
    return "{%s}" % body
374
375
376
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that emits unsupported objects as ``"UNSERIALIZABLE[<repr>]"`` strings instead of raising."""

    def default(self, o):
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
379
380
381
def safe_dumps(obj, **kwargs):
    """``json.dumps`` with SafeJSONEncoder. Other keyword arguments pass through; supplying ``cls`` is an error."""
    encoder = SafeJSONEncoder
    return json.dumps(obj, cls=encoder, **kwargs)
383
384
385
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` so each item is announced on ``terminal_reporter`` as it is yielded.

    ``format_string`` is formatted with ``pos`` (1-based), ``total``, ``value``
    and any extra ``kwargs``. The returned generator yields
    ``(message, item)`` pairs. ``len(iterable)`` is taken immediately, so
    ``iterable`` must be sized.
    """
    total = len(iterable)

    def progress_reporting_wrapper():
        for number, item in enumerate(iterable, 1):
            message = format_string.format(pos=number, total=total, value=item, **kwargs)
            # presumably a pytest TerminalReporter; only .rewrite() is used
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return progress_reporting_wrapper()
394
395
396
def report_noprogress(iterable, *args, **kwargs):
    """Silent counterpart of report_progress: yield ``("", item)`` for each item, ignoring extra arguments."""
    for item in iterable:
        yield "", item
399
400
401
def slugify(name):
    """Replace path-unsafe characters and spaces in ``name`` with underscores.

    Each character in ``\\ / : * ? < > |`` and space is replaced in turn. After
    each replacement a single ``'__' -> '_'`` pass runs, so longer underscore
    runs may be only partially collapsed.
    """
    # Raw string: in a plain literal "\/" is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). The character set is unchanged.
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
405
406
407
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are treated Windows-style (``ntpath``):

    * ``/`` and ``\\`` are both accepted as separators.
    * Comparison is case-insensitive.
    * Empty and ``.`` components are ignored.
    * The result is joined with ``\\`` and keeps the spelling of the first path.

    All paths must be ``str``, or all ``bytes``.

    :raises ValueError: for an empty sequence, a mix of absolute and relative
        paths, or paths on different drives.
    :raises TypeError: when str and bytes paths are mixed or an argument is
        not a path.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # separators must have the same type (str vs bytes) as the inputs
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        # lower-cased, backslash-normalised copies are used only for comparison
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # the returned components keep the original spelling of paths[0]
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        # the prefix shared by the lexicographic min and max is shared by all
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        # Prefer CPython's clearer TypeError for mixed str/bytes or non-path
        # arguments. _check_arg_types is a private helper that some
        # interpreters lack; without it, re-raise the original error rather
        # than an unrelated AttributeError about genericpath.
        check_arg_types = getattr(genericpath, '_check_arg_types', None)
        if check_arg_types is not None:
            check_arg_types('commonpath', *paths)
        raise
456
457
458
def get_cprofile_functions(stats, sort_by='cumtime', reverse=True):
    """
    Convert a pstats structure to a list of sorted dicts, one per function.

    Each dict has the keys ``ncalls_recursion``, ``ncalls``, ``tottime``,
    ``tottime_per``, ``cumtime``, ``cumtime_per`` and ``function_name``.
    You can sort by any of these keys with ``sort_by``; the sort direction is
    controlled by ``reverse`` (descending by default).

    :param stats: object whose ``stats`` attribute maps
        ``(filename, lineno, funcname)`` to ``(cc, nc, tt, ct, callers)``,
        as in :class:`pstats.Stats`.
    :raises KeyError: if ``sort_by`` is not one of the keys above.
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = os.path.dirname(os.getcwd())
    # Strip whole leading directories only: a bare startswith(project_dir_parent)
    # would also cut into siblings such as "/home/username" for "/home/user".
    parent_prefix = project_dir_parent.rstrip(os.sep) + os.sep

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(parent_prefix):
            file_path = file_path[len(parent_prefix):]
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # pstats tuple: cc = primitive (non-recursive) calls, nc = total calls,
        # tt = own time, ct = cumulative time
        primitive_calls = run_info[0]
        total_calls = run_info[1]
        tottime = run_info[2]
        cumtime = run_info[3]

        # if the function is recursive write number of 'total calls/primitive calls'
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{0}/{1}'.format(total_calls, primitive_calls)

        # Per-call figures follow pstats' own report: tt / nc and ct / cc.
        result.append(dict(ncalls_recursion=calls,
                           ncalls=total_calls,
                           tottime=tottime,
                           tottime_per=tottime / total_calls if total_calls > 0 else 0,
                           cumtime=cumtime,
                           cumtime_per=cumtime / primitive_calls if primitive_calls > 0 else 0,
                           function_name=function_name))

    result.sort(key=operator.itemgetter(sort_by), reverse=reverse)

    return result
492