Completed: Pull Request on master (#93), 29s

parse_columns()    grade: B

Complexity     Conditions: 5
Size           Total Lines: 9
Duplication    Lines: 0, Ratio: 0 %
Importance     Changes: 0

Metric    Value
cc        5
dl        0
loc       9
rs        8.5454
c         0
b         0
f         0
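
The function graded above, parse_columns(), lowercases a comma-separated column spec, strips each entry, and accepts either a known column name or a pXX[.XX] percentile column; anything else raises argparse.ArgumentTypeError. A minimal usage sketch (the import path is an assumption):

    from pytest_benchmark.utils import parse_columns  # import path assumed

    parse_columns("Min, MAX, mean, p99.9")   # -> ['min', 'max', 'mean', 'p99.9']
    parse_columns("min, bogus")              # raises argparse.ArgumentTypeError

The module source under review follows.
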
from __future__ import division
from __future__ import print_function

import argparse
import genericpath
import json
import netrc
import ntpath
import os
import platform
import re
import subprocess
import sys
import types
from datetime import datetime
from decimal import Decimal
from functools import partial
from os.path import basename
from os.path import dirname
from os.path import exists
from os.path import join
from os.path import split

from .compat import PY3

try:
    from urllib.parse import urlparse, parse_qs
except ImportError:
    from urlparse import urlparse, parse_qs

try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    class CalledProcessError(subprocess.CalledProcessError):
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output

TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
PERCENTILE_COL_RX = re.compile(r'p(\d+(?:\.\d+)?)')


class SecondsDecimal(Decimal):
    def __float__(self):
        return float(super(SecondsDecimal, self).__str__())

    def __str__(self):
        return "{0}s".format(format_time(float(super(SecondsDecimal, self).__str__())))

    @property
    def as_string(self):
        return super(SecondsDecimal, self).__str__()
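
# Illustrative behaviour (sketch, not part of the module):
#
#     str(SecondsDecimal("0.001"))        # -> "1.00ms"
#     SecondsDecimal("0.001").as_string   # -> "0.001"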


class NameWrapper(object):
    def __init__(self, target):
        self.target = target

    def __str__(self):
        name = self.target.__module__ + "." if hasattr(self.target, '__module__') else ""
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)


def get_tag(project_name=None):
    info = get_commit_info(project_name)
    parts = [info['id'], get_current_time()]
    if info['dirty']:
        parts.append("uncommitted-changes")
    return "_".join(parts)


def get_machine_id():
    return "%s-%s-%s-%s" % (
        platform.system(),
        platform.python_implementation(),
        ".".join(platform.python_version_tuple()[:2]),
        platform.architecture()[0]
    )
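
# Illustrative value (sketch, not part of the module): on a 64-bit CPython 3.8
# Linux interpreter, get_machine_id() returns something like "Linux-CPython-3.8-64bit".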


class Fallback(object):
    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        else:
            raise NotImplementedError("Must have a fallback implementation for %s" % self.functions[0])

    def register(self, other):
        self.functions.append(other)
        return self


get_project_name = Fallback(IndexError, CalledProcessError, OSError)


@get_project_name.register
def get_project_name_git():
    project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if isinstance(project_address, bytes) and str != bytes:
        project_address = project_address.decode()
    project_name = [i for i in re.split(r'[/:\s\\]|\.git', project_address) if i][-1]
    return project_name.strip()


@get_project_name.register
def get_project_name_hg():
    project_address = check_output(['hg', 'path', 'default'])
    project_address = project_address.decode()
    project_name = project_address.split("/")[-1]
    return project_name.strip()


@get_project_name.register
def get_project_name_default():
    return basename(os.getcwd())
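
# Descriptive note (added, not part of the module): get_project_name tries the
# registered callables in order (git remote URL, then the hg default path, then the
# basename of the current working directory) and returns the first result whose call
# does not raise one of the exceptions given to Fallback(IndexError,
# CalledProcessError, OSError).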


def in_any_parent(name, path=None):
    prev = None
    if not path:
        path = os.getcwd()
    while path and prev != path and not exists(join(path, name)):
        prev = path
        path = dirname(path)
    return exists(join(path, name))


def subprocess_output(cmd):
    return check_output(cmd.split(), stderr=subprocess.STDOUT, universal_newlines=True).strip()


def get_commit_info(project_name=None):
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            commit = desc[-1].strip('g')
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            desc = subprocess_output('hg id --id --debug')
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
                     if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }
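
# Illustrative shape of the returned mapping (values below are invented):
#
#     {'id': '1a2b3c4d5e6f', 'time': '2016-05-06T09:00:00+02:00',
#      'author_time': '2016-05-06T08:55:00+02:00', 'dirty': False,
#      'project': 'my-project', 'branch': 'master'}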


def get_current_time():
    return datetime.utcnow().strftime("%Y%m%d_%H%M%S")


def first_or_value(obj, value):
    if obj:
        value, = obj

    return value


def short_filename(path, machine_id=None):
    parts = []
    try:
        last = len(path.parts) - 1
    except AttributeError:
        return str(path)
    for pos, part in enumerate(path.parts):
        if not pos and part == machine_id:
            continue
        if pos == last:
            part = part.rsplit('.', 1)[0]
            # if len(part) > 16:
            #     part = "%.13s..." % part
        parts.append(part)
    return '/'.join(parts)


def load_timer(string):
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod == 'pep418':
        if PY3:
            import time
            return NameWrapper(getattr(time, attr))
        else:
            from . import pep418
            return NameWrapper(getattr(pep418, attr))
    else:
        __import__(mod)
        mod = sys.modules[mod]
        return NameWrapper(getattr(mod, attr))
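
# Illustrative usage (sketch, not part of the module):
#
#     timer = load_timer("time.perf_counter")
#     str(timer)   # -> "time.perf_counter" (a NameWrapper around time.perf_counter)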


class RegressionCheck(object):
    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        val = self.compute(current, compared)
        if val > self.threshold:
            return "Field %r has failed %s: %.9f > %.9f" % (
                self.field, self.__class__.__name__, val, self.threshold
            )


class PercentageRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        val = compared[self.field]
        if not val:
            return float("inf")
        return current[self.field] / val * 100 - 100


class DifferenceRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        return current[self.field] - compared[self.field]


def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|'
                                      r'(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
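
# Illustrative parses (sketch, not part of the module):
#
#     parse_compare_fail("min:5%")      # -> PercentageRegressionCheck('min', 5)
#     parse_compare_fail("mean:0.001")  # -> DifferenceRegressionCheck('mean', 0.001)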


def parse_warmup(string):
    string = string.lower().strip()
    if string == "auto":
        return platform.python_implementation() == "PyPy"
    elif string in ["off", "false", "no"]:
        return False
    elif string in ["on", "true", "yes", ""]:
        return True
    else:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def name_formatter_short(bench):
    name = bench["name"]
    if bench["source"]:
        name = "%s (%.4s)" % (name, split(bench["source"])[-1])
    if name.startswith("test_"):
        name = name[5:]
    return name


def name_formatter_normal(bench):
    name = bench["name"]
    if bench["source"]:
        parts = bench["source"].split('/')
        parts[-1] = parts[-1][:12]
        name = "%s (%s)" % (name, '/'.join(parts))
    return name


def name_formatter_long(bench):
    if bench["source"]:
        return "%(fullname)s (%(source)s)" % bench
    else:
        return bench["fullname"]


NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
}


def parse_name_format(string):
    string = string.lower().strip()
    if string in NAME_FORMATTERS:
        return string
    else:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def parse_timer(string):
    return str(load_timer(string))


def parse_sort(string):
    string = string.lower().strip()
    if string not in ("min", "max", "mean", "stddev", "name", "fullname"):
        raise argparse.ArgumentTypeError(
            "Unacceptable value: %r. "
            "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
            "'stddev', 'name', 'fullname'." % string)
    return string


def parse_columns(string):
    columns = [str.strip(s) for s in string.lower().split(',')]
    invalid = set(columns) - set(ALLOWED_COLUMNS) - set(c for c in columns if PERCENTILE_COL_RX.match(c))
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s, pXX.XX" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns


def parse_rounds(string):
    try:
        value = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    else:
        if value < 1:
            raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
        return value


def parse_seconds(string):
    try:
        return SecondsDecimal(string).as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))


def parse_save(string):
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    illegal = ''.join(c for c in r"\/:*?<>|" if c in string)
    if illegal:
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string


def _parse_hosts(storage_url, netrc_file):

    # load creds from netrc file
    path = os.path.expanduser(netrc_file)
    creds = None
    if netrc_file and os.path.isfile(path):
        creds = netrc.netrc(path)

    # add creds to urls
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds and '@' not in netloc:
            host = netloc.split(':').pop(0)
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls


def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index = default_index
    doctype = default_doctype
    if storage_url.path and storage_url.path != "/":
        splitted = storage_url.path.strip("/").split("/")
        index = splitted[0]
        if len(splitted) >= 2:
            doctype = splitted[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
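
# Illustrative parse (sketch, not part of the module), with no netrc file involved:
#
#     parse_elasticsearch_storage("http://localhost:9200/myindex/mydoctype")
#     # -> (['http://localhost:9200'], 'myindex', 'mydoctype', <project name>)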


def load_storage(storage, **kwargs):
    if "://" not in storage:
        storage = "file://" + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in the form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")


def time_unit(value):
    if value < 1e-6:
        return "n", 1e9
    elif value < 1e-3:
        return "u", 1e6
    elif value < 1:
        return "m", 1e3
    else:
        return "", 1.


def operations_unit(value):
    if value > 1e+6:
        return "M", 1e-6
    if value > 1e+3:
        return "K", 1e-3
    return "", 1.


def format_time(value):
    unit, adjustment = time_unit(value)
    return "{0:.2f}{1:s}".format(value * adjustment, unit)


class cached_property(object):
    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value


def funcname(f):
    try:
        if isinstance(f, partial):
            return f.func.__name__
        else:
            return f.__name__
    except AttributeError:
        return str(f)


def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    generated assembler.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        return types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
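
# Illustrative usage (sketch, not part of the module):
#
#     def candidate():
#         return sum(range(100))
#
#     cloned = clonefunc(candidate)
#     cloned() == candidate()                    # same behaviour
#     cloned.__code__ is not candidate.__code__  # but a distinct code object,
#                                                # so PyPy traces it separately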


def format_dict(obj):
    return "{%s}" % ", ".join("%s: %s" % (k, json.dumps(v)) for k, v in sorted(obj.items()))


class SafeJSONEncoder(json.JSONEncoder):
    def default(self, o):
        return "UNSERIALIZABLE[%r]" % o


def safe_dumps(obj, **kwargs):
    return json.dumps(obj, cls=SafeJSONEncoder, **kwargs)


def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    total = len(iterable)

    def progress_reporting_wrapper():
        for pos, item in enumerate(iterable):
            string = format_string.format(pos=pos + 1, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(string, black=True, bold=True)
            yield string, item

    return progress_reporting_wrapper()


def report_noprogress(iterable, *args, **kwargs):
    for pos, item in enumerate(iterable):
        yield "", item


def slugify(name):
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name


def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path."""

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
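
# Illustrative behaviour (sketch, not part of the module), using the Windows-style
# separators this backport handles:
#
#     commonpath(['C:\\ham\\spam\\eggs', 'C:\\ham\\spam'])   # -> 'C:\\ham\\spam'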


def get_cprofile_functions(stats):
    """
    Convert a pstats structure to a list of dicts describing each function.
    """
    result = []
    # this assumes that py.test is run from the project root dir
    project_dir_parent = dirname(os.getcwd())

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive, record 'total calls/primitive calls'
        if run_info[0] == run_info[1]:
            calls = str(run_info[0])
        else:
            calls = '{1}/{0}'.format(run_info[0], run_info[1])

        result.append(dict(ncalls_recursion=calls,
                           ncalls=run_info[1],
                           tottime=run_info[2],
                           tottime_per=run_info[2] / run_info[0] if run_info[0] > 0 else 0,
                           cumtime=run_info[3],
                           cumtime_per=run_info[3] / run_info[0] if run_info[0] > 0 else 0,
                           function_name=function_name))

    return result
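
# Illustrative usage (sketch, not part of the module), building a pstats.Stats
# object from cProfile:
#
#     import cProfile
#     import pstats
#
#     profiler = cProfile.Profile()
#     profiler.runcall(sum, range(1000))
#     rows = get_cprofile_functions(pstats.Stats(profiler))
#     # each row is a dict with 'function_name', 'ncalls', 'tottime', 'cumtime', etc.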