Completed
Push — master ( 163501...3dcdfa )
by Ionel Cristian
03:38
created

Fallback.__call__()   B

Complexity

Conditions 6

Size

Total Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
dl 0
loc 11
rs 8
c 0
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    # Fallback for a subprocess module that lacks check_output: re-creates
    # its interface on top of Popen.
    # NOTE(review): presumably only needed on very old Python 2 releases —
    # confirm before removing.
    class CalledProcessError(subprocess.CalledProcessError):
        # Adds the ``output`` attribute, which get_commit_info() reads when
        # formatting error reports.
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command and return its captured stdout.

        Raises CalledProcessError (carrying the captured output) when the
        process exits with a non-zero status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            # Report the command from ``args=`` when given as a keyword,
            # otherwise the first positional argument.
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output
51
52
# Human-readable names for the unit prefixes produced by time_unit().
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by parse_columns(); this order is the one listed in
# its error message.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
59
60
61
class SecondsDecimal(Decimal):
    """A Decimal number of seconds.

    ``str()`` renders it human-readably via ``format_time`` with an ``s``
    suffix; ``float()`` and ``as_string`` use the plain decimal text.
    """

    def __float__(self):
        plain_text = super(SecondsDecimal, self).__str__()
        return float(plain_text)

    def __str__(self):
        seconds = float(super(SecondsDecimal, self).__str__())
        pretty = format_time(seconds)
        return "{0}s".format(pretty)

    @property
    def as_string(self):
        """The raw decimal representation, without unit formatting."""
        plain_text = super(SecondsDecimal, self).__str__()
        return plain_text
71
72
73
class NameWrapper(object):
    """Wrap an object so that ``str()`` gives a dotted ``module.name`` label.

    Objects without ``__name__`` fall back to their ``repr``; objects
    without ``__module__`` get no module prefix.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        target = self.target
        pieces = []
        if hasattr(target, '__module__'):
            pieces.append(target.__module__ + ".")
        if hasattr(target, '__name__'):
            pieces.append(target.__name__)
        else:
            pieces.append(repr(target))
        return "".join(pieces)

    def __repr__(self):
        return "NameWrapper(%r)" % (self.target,)
84
85
86
def get_tag(project_name=None):
    """Build a tag from the commit id and current UTC timestamp.

    The format is ``<id>_<YYYYmmdd_HHMMSS>``, with ``_uncommited-changes``
    appended when the working tree is dirty.
    """
    info = get_commit_info(project_name)
    timestamp = get_current_time()
    extra = ("uncommited-changes",) if info['dirty'] else ()
    return "_".join((info['id'], timestamp) + extra)
92
93
94
def get_machine_id():
    """Describe the host/interpreter combination.

    Returns ``<system>-<implementation>-<major.minor>-<architecture bits>``.
    """
    major_minor = ".".join(platform.python_version_tuple()[:2])
    fields = (
        platform.system(),
        platform.python_implementation(),
        major_minor,
        platform.architecture()[0],
    )
    return "-".join(fields)
101
102
103
class Fallback(object):
    """Callable that tries registered implementations before a fallback.

    Functions added with :meth:`register` are tried in registration order.
    The first one that returns a truthy value wins. A function that raises
    one of ``exceptions``, or returns a falsy value, is skipped. If every
    registered function is skipped, ``fallback`` is called instead.

    :param fallback: callable invoked when no registered function succeeds.
    :param exceptions: exception class (or tuple of classes) meaning "this
        implementation is unavailable". Any other exception propagates.
    """

    def __init__(self, fallback, exceptions):
        self.fallback = fallback
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                value = func(*args, **kwargs)
            except self.exceptions:
                continue
            # Only a truthy result counts as success; falsy results fall
            # through to the next candidate.
            if value:
                return value
        return self.fallback(*args, **kwargs)

    def register(self, other):
        """Append ``other`` to the candidates and return ``self``.

        Because ``self`` is returned, using this as a decorator binds the
        decorated name to this Fallback object.
        """
        self.functions.append(other)
        return self
124
125
126
@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    """Last-resort project name: the current working directory's name.

    The decorator turns ``get_project_name`` into a Fallback object. The
    VCS-specific lookups registered on it are tried first; IndexError,
    CalledProcessError and OSError from them mean "not available".
    """
    cwd = os.getcwd()
    return basename(cwd)
129
130
131
@get_project_name.register
def get_project_name_git():
    """Derive the project name from the git ``remote.origin.url``.

    Handles ``https://host/user/repo.git``, ``git@host:user/repo.git`` and
    local paths. Only a *trailing* ``.git`` suffix is removed, so names
    that merely contain ``.git`` (e.g. ``user.github.io``) stay intact.
    Returns None if ``git rev-parse --git-dir`` prints nothing.
    """
    is_git = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if is_git:
        project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
        if isinstance(project_address, bytes) and str != bytes:
            project_address = project_address.decode()
        # Last non-empty component of the URL/path. An empty URL raises
        # IndexError, which the Fallback wrapper treats as "unavailable".
        project_name = [i for i in re.split(r'[/:\s\\]', project_address) if i][-1]
        if project_name.endswith('.git'):
            project_name = project_name[:-len('.git')]
        return project_name.strip()
140
141
142
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the Mercurial ``default`` path.

    Uses the last non-empty ``/`` or ``\\`` separated component. Trailing
    separators and Windows-style local paths therefore still give the
    repository name, consistent with get_project_name_git.
    """
    project_address = check_output(['hg', 'path', 'default', '--quiet'])
    if isinstance(project_address, bytes) and str != bytes:
        project_address = project_address.decode()
    # Nothing left after splitting raises IndexError, which the Fallback
    # wrapper treats as "unavailable".
    project_name = [i for i in re.split(r'[/\\]', project_address.strip()) if i][-1]
    return project_name.strip()
148
149
150
def in_any_parent(name, path=None):
    """Return True if ``name`` exists in ``path`` or any of its ancestors.

    ``path`` defaults to the current working directory. The walk stops at
    the filesystem root, when ``dirname`` stops changing the path, or when
    the path becomes empty.
    """
    current = path or os.getcwd()
    previous = None
    while current and current != previous:
        if exists(join(current, name)):
            return True
        previous, current = current, dirname(current)
    return exists(join(current, name))
158
159
160
def subprocess_output(cmd):
    """Run ``cmd`` (split on whitespace, no shell) and return its stripped text output.

    stderr is merged into the returned output.
    """
    argv = cmd.split()
    text = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return text.strip()
162
163
164
def get_commit_info(project_name=None):
    """Collect version-control information about the current directory.

    A git checkout (``.git`` in the cwd or any parent) is checked first,
    then Mercurial (``.hg``). Failures while querying the VCS never
    propagate; the exception is reported under an ``'error'`` key instead.

    :param project_name: project name to report. When falsy, it is
        resolved via ``get_project_name()``.
    :return: dict with ``id``, ``time``, ``author_time``, ``dirty``,
        ``project`` and ``branch`` keys, plus ``error`` on failure.
        ``id`` is ``'unversioned'`` when no repository is found and
        ``'unknown'`` when querying the VCS raised.
    """
    # Defaults, returned unchanged when neither .git nor .hg is found.
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            # The last '-'-separated field is taken as the commit id. A
            # trailing "dirty" field marks uncommitted changes and is dropped
            # first. strip('g') removes surrounding 'g' characters (the code
            # expects a 'g' before the hash).
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            commit = desc[-1].strip('g')
            # Committer (%cI) and author (%aI) dates. subprocess_output splits
            # on whitespace without a shell, so the double quotes reach git
            # literally and are stripped from its output here.
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            # A literal "HEAD" means no branch is checked out.
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            desc = subprocess_output('hg id --id --debug')
            # A trailing '+' on the id marks uncommitted changes. An empty id
            # raises IndexError here and is reported via the except below.
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            # NOTE(review): `hg tip` describes the repository tip, which may
            # differ from the working directory's parent revision — confirm.
            # author_time is not collected for Mercurial.
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        # Best effort: report the failure instead of raising. For a failed
        # command, include its exit status and output.
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
                     if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }
210
211
212
def get_current_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    now = datetime.utcnow()
    return "{0:%Y%m%d_%H%M%S}".format(now)
214
215
216
def first_or_value(obj, value):
    """Return the single element of ``obj``, or ``value`` if ``obj`` is falsy.

    A truthy ``obj`` must hold exactly one element; otherwise the unpacking
    raises ValueError.
    """
    if not obj:
        return value
    (only_item,) = obj
    return only_item
221
222
223
def short_filename(path, machine_id=None):
    """Shorten a path-like object to a ``/``-joined label.

    A leading component equal to ``machine_id`` is dropped, and the
    extension of the last component is removed. Objects without ``.parts``
    (e.g. plain strings) are returned as ``str(path)``.
    """
    try:
        components = path.parts
    except AttributeError:
        return str(path)
    last_index = len(components) - 1
    kept = []
    for index, component in enumerate(components):
        if index == 0 and component == machine_id:
            continue
        if index == last_index:
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
238
239
240
def load_timer(string):
    """Resolve a dotted ``module.attr`` timer spec into a NameWrapper.

    The pseudo-module ``pep418`` maps to ``time`` on Python 3 and to the
    bundled ``pep418`` backport otherwise.

    :raises argparse.ArgumentTypeError: if ``string`` has no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr = string.rsplit(".", 1)
    if module_name == 'pep418':
        if PY3:
            import time as source
        else:
            from . import pep418 as source
        return NameWrapper(getattr(source, attr))
    __import__(module_name)
    return NameWrapper(getattr(sys.modules[module_name], attr))
255
256
257
class RegressionCheck(object):
    """Base rule for failing a comparison on one stats field.

    Subclasses implement ``compute(current, compared)``. The rule fails
    when the computed value is strictly greater than ``threshold``.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message if the threshold is exceeded, else None."""
        val = self.compute(current, compared)
        exceeded = val > self.threshold
        if not exceeded:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, val, self.threshold
        )
268
269
270
class PercentageRegressionCheck(RegressionCheck):
    """Measure how many percent ``current`` is above ``compared`` for the field."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A falsy (e.g. zero) baseline makes any value an infinite regression.
        return float("inf")
276
277
278
class DifferenceRegressionCheck(RegressionCheck):
    """Measure how far ``current`` is above ``compared``, in the field's own units."""

    def compute(self, current, compared):
        key = self.field
        return current[key] - compared[key]
281
282
283
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|'
                                      r'(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse an expression such as ``min:5%`` or ``mean:0.001``.

    ``<field>:<N>%`` (N is 0-99) gives a PercentageRegressionCheck.
    ``<field>:<number>`` gives a DifferenceRegressionCheck. The pattern
    uses raw strings so that ``\\.`` is not an invalid string escape.

    :raises argparse.ArgumentTypeError: if ``string`` does not match.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
296
297
298
def parse_warmup(string):
    """Parse a warmup setting into a bool.

    ``auto`` enables warmup only on PyPy. on/true/yes/empty enable it;
    off/false/no disable it. Matching is case-insensitive.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    choices = ((False, ("off", "false", "no")), (True, ("on", "true", "yes", "")))
    for result, spellings in choices:
        if normalized in spellings:
            return result
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
308
309
310
def name_formatter_short(bench):
    """Short display name.

    Uses the test name plus the first four characters of the source's last
    path component, with a leading ``test_`` removed.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        tail = split(source)[-1]
        label = "%s (%.4s)" % (label, tail)
    return label[5:] if label.startswith("test_") else label
317
318
319
def name_formatter_normal(bench):
    """Display name with the source appended.

    The source's last ``/``-separated component is cut to 12 characters.
    """
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    head, sep, tail = source.rpartition('/')
    return "%s (%s)" % (label, head + sep + tail[:12])
326
327
328
def name_formatter_long(bench):
    """Full benchmark name, followed by the full source in parentheses if set."""
    source = bench["source"]
    if not source:
        return bench["fullname"]
    return "%s (%s)" % (bench["fullname"], source)
333
334
335
# Maps the lower-cased names accepted by parse_name_format() to the function
# that renders a benchmark's display name.
NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
}
340
341
342
def parse_name_format(string):
    """Validate a name-format choice (a key of NAME_FORMATTERS), case-insensitively."""
    normalized = string.lower().strip()
    if normalized not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
    return normalized
348
349
350
def parse_timer(string):
    """Resolve a dotted timer spec and return its ``module.name`` label."""
    timer = load_timer(string)
    label = str(timer)
    return label
352
353
354
def parse_sort(string):
    """Validate the ``--benchmark-sort`` key, case-insensitively."""
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
362
363
364
def parse_columns(string):
    """Parse a comma-separated list of column names.

    Names are lower-cased and stripped of surrounding whitespace. Their
    order (and any duplicates) is preserved in the result.

    :raises argparse.ArgumentTypeError: if any name is not in
        ALLOWED_COLUMNS. Invalid names are reported once each, in the order
        they first appear, so the message is deterministic.
    """
    columns = [s.strip() for s in string.lower().split(',')]
    invalid = []
    for column in columns:
        if column not in ALLOWED_COLUMNS and column not in invalid:
            invalid.append(column)
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
373
374
375
def parse_rounds(string):
    """Parse ``--benchmark-rounds``: an integer that must be at least 1."""
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
384
385
386
def parse_seconds(string):
    """Validate a decimal number of seconds and return its plain decimal text."""
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
391
392
393
def parse_save(string):
    """Validate a save name.

    The name must be non-empty and must not contain ``\\ / : * ? < > |``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    forbidden = r"\/:*?<>|"
    illegal = ''.join([ch for ch in forbidden if ch in string])
    if not illegal:
        return string
    raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
400
401
402
def _parse_hosts(storage_url, netrc_file):
    """Build one URL per comma-separated host in ``storage_url.netloc``.

    When ``netrc_file`` names an existing netrc file, credentials found
    there for a host are inserted as ``user:password@``. Host entries that
    already contain ``@`` are left as they are.

    :param storage_url: result of ``urlparse`` on the storage string.
    :param netrc_file: path to a netrc file (``~`` is expanded). A falsy
        value (``''`` or ``None``) disables the credential lookup.
    :return: list of ``scheme://[auth]netloc`` strings.
    """
    # load creds from netrc file; skip entirely when no file was given so a
    # None value never reaches os.path.expanduser (which would TypeError).
    creds = None
    if netrc_file:
        path = os.path.expanduser(netrc_file)
        if os.path.isfile(path):
            creds = netrc.netrc(path)

    # add creds to urls
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds and '@' not in netloc:
            host = netloc.split(':')[0]
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls
424
425
426
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an Elasticsearch storage URL into its parts.

    Expected form: ``scheme://host1,host2/<index>/<doctype>?project_name=<name>``.
    A missing index or doctype falls back to the defaults. A missing
    project name falls back to ``get_project_name()``.

    :return: ``(hosts, index, doctype, project_name)``.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)

    index, doctype = default_index, default_doctype
    url_path = storage_url.path
    if url_path and url_path != "/":
        segments = url_path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]

    project_names = parse_qs(storage_url.query).get("project_name")
    project_name = project_names[0] if project_names else get_project_name()
    return hosts, index, doctype, project_name
443
444
445
def load_storage(storage, **kwargs):
    """Instantiate the storage backend described by ``storage``.

    ``storage`` is either a path / ``file://path`` (FileStorage) or
    ``elasticsearch+http[s]://host1,host2/index/doctype``
    (ElasticsearchStorage). Remaining ``kwargs`` go to the storage class.

    :param netrc: optional keyword; netrc file used only for Elasticsearch
        credentials. Defaults to ``''`` (no lookup), the same default
        ``parse_elasticsearch_storage`` uses.
    :raises argparse.ArgumentTypeError: for any other scheme.
    """
    if "://" not in storage:
        storage = "file://" + storage
    # Always removed from kwargs so it is never forwarded to the storage
    # constructors; a missing key no longer raises KeyError.
    netrc_file = kwargs.pop('netrc', '')  # only used by elasticsearch storage
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")
461
462
463
def time_unit(value):
    """Choose a display prefix and multiplier for a duration in seconds.

    Returns ``(prefix, scale)``: ``n`` below 1µs, ``u`` below 1ms, ``m``
    below 1s, otherwise ``""`` with a scale of 1.
    """
    for limit, prefix, scale in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < limit:
            return prefix, scale
    return "", 1.
472
473
474
def operations_unit(value):
    """Choose a display prefix and multiplier for an operation count.

    Returns ``("M", 1e-6)`` above one million, ``("K", 1e-3)`` above one
    thousand, otherwise ``("", 1.)``.
    """
    for limit, prefix, scale in ((1e+6, "M", 1e-6), (1e+3, "K", 1e-3)):
        if value > limit:
            return prefix, scale
    return "", 1.
480
481
482
def format_time(value):
    """Format seconds with two decimals and a unit prefix (``0.0015`` -> ``1.50m``)."""
    unit, adjustment = time_unit(value)
    scaled = value * adjustment
    return "%.2f%s" % (scaled, unit)
485
486
487
class cached_property(object):
    """Descriptor that computes a value once per instance.

    The first access stores the result in the instance ``__dict__`` under
    the function's name. Because this class defines no ``__set__``, that
    entry then shadows the descriptor on later lookups.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = self.func(obj)
        obj.__dict__[self.func.__name__] = value
        return value
497
498
499
def funcname(f):
    """Best-effort name of a callable.

    ``functools.partial`` objects are unwrapped one level; anything without
    a ``__name__`` falls back to ``str(f)``.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
507
508
509
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects without a ``__code__`` (builtins, partials, callable instances)
    are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+ changed the positional CodeType() signature
            # (posonlyargcount, later qualname/exceptiontable), so the calls
            # below would raise TypeError there. replace() builds a new code
            # object with identical contents on every such version.
            co2 = co.replace()
        elif PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() takes no keyword-only defaults; copy them so the
        # clone accepts exactly the same calls as the original.
        kwdefaults = getattr(f, '__kwdefaults__', None)
        if kwdefaults:
            clone.__kwdefaults__ = dict(kwdefaults)
        return clone
    except AttributeError:
        return f
541
542
543
def format_dict(obj):
    """Render ``obj`` as ``{key: <JSON value>, ...}`` with items sorted."""
    entries = []
    for key, val in sorted(obj.items()):
        entries.append("%s: %s" % (key, json.dumps(val)))
    return "{" + ", ".join(entries) + "}"
545
546
547
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that turns unserializable objects into ``"UNSERIALIZABLE[<repr>]"``."""

    def default(self, o):
        return "UNSERIALIZABLE[{0!r}]".format(o)
550
551
552
def safe_dumps(obj, **kwargs):
    """``json.dumps`` that stringifies unserializable objects instead of raising."""
    encoder_class = SafeJSONEncoder
    return json.dumps(obj, cls=encoder_class, **kwargs)
554
555
556
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` so that each item is announced on the terminal.

    ``len(iterable)`` is taken immediately. The returned generator formats
    ``format_string`` with ``pos`` (1-based), ``total``, ``value`` and any
    extra ``kwargs``, rewrites the terminal line with it, and then yields
    ``(string, item)``.
    """
    total = len(iterable)

    def _announce():
        for position, item in enumerate(iterable, 1):
            line = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(line, black=True, bold=True)
            yield line, item

    return _announce()
566
567
568
def report_noprogress(iterable, *args, **kwargs):
    """Silent drop-in for report_progress: yields ``("", item)`` for each item."""
    for item in iterable:
        yield "", item
571
572
573
def slugify(name):
    """Replace filename-unsafe characters and spaces with underscores.

    After each character is replaced, occurrences of ``__`` are collapsed to
    ``_`` (one non-overlapping pass per character, as before). The character
    set is a raw string; the previous ``"\\/..."`` literal relied on an
    invalid escape sequence, which newer Pythons warn about.
    """
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
577
578
579
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    The logic appears to be adapted from CPython's ``ntpath.commonpath`` and
    always uses Windows semantics. Both ``/`` and ``\\`` count as
    separators, comparison is case-insensitive, and the result is joined
    with ``\\`` using the original case of ``paths[0]``.
    NOTE(review): on POSIX this yields backslash-separated paths — confirm
    callers expect that.

    :param paths: non-empty indexable sequence of str or bytes paths, all
        absolute or all relative, on the same drive.
    :raises ValueError: for an empty sequence, a mix of absolute and relative
        paths, or differing drives.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Separator constants must match the path type (str vs bytes).
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        # Normalised (backslash, lower-case) copies, used only for comparing.
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        # Exactly one distinct "starts with a separator" value is allowed.
        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result is built from paths[0] without lower-casing, so it keeps
        # that path's original case.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        # Empty and '.' components are ignored. The lexicographically
        # smallest and largest component lists bound the rest: a prefix they
        # share is shared by every list.
        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        # Presumably lets genericpath raise a clearer TypeError for mixed
        # str/bytes or non-path items; otherwise the original error is re-raised.
        genericpath._check_arg_types('commonpath', *paths)
        raise
628
629
630
def get_cprofile_functions(stats):
    """
    Convert a pstats structure to a list of dicts, one per profiled function.

    The list follows ``stats.stats`` iteration order; it is not sorted.
    File paths under the parent of the current working directory are made
    relative to it. Per-call times follow the pstats report convention:
    ``tottime_per`` divides by the total call count and ``cumtime_per`` by
    the primitive call count.
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())
    # Match whole path components only: a bare startswith(project_dir_parent)
    # would also strip sibling directories that share its name as a prefix.
    prefix = project_dir_parent.rstrip(os.sep) + os.sep

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(prefix):
            file_path = file_path[len(prefix):]
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # run_info starts with (primitive calls, total calls, tottime, cumtime),
        # as implied by the 'total/primitive' formatting below.
        primitive_calls, total_calls, tottime, cumtime = run_info[:4]

        # if the function is recursive write number of 'total calls/primitive calls'
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{0}/{1}'.format(total_calls, primitive_calls)

        result.append(dict(ncalls_recursion=calls,
                           ncalls=total_calls,
                           tottime=tottime,
                           tottime_per=tottime / total_calls if total_calls > 0 else 0,
                           cumtime=cumtime,
                           cumtime_per=cumtime / primitive_calls if primitive_calls > 0 else 0,
                           function_name=function_name))

    return result
659