Completed
Pull Request — master (#78)
by
unknown
01:17
created

operations_unit()   A

Complexity

Conditions 3

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 1
Metric Value
cc 3
dl 0
loc 6
rs 9.4285
c 1
b 0
f 1
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
# Compatibility shim: subprocess.check_output only exists on Python >= 2.7.
# On older interpreters we provide equivalent implementations, including a
# CalledProcessError variant that carries the captured output.
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    class CalledProcessError(subprocess.CalledProcessError):
        # Same as the stdlib exception, but keeps the child's captured output.
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        # Backport of subprocess.check_output: run a command, return its
        # stdout, and raise CalledProcessError on a non-zero exit code.
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output
51
52
# Maps a unit prefix to the human-readable label shown in benchmark reports.
# Fix: "Miliseconds" was misspelled (display label only).
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by --benchmark-columns.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
59
60
61
class SecondsDecimal(Decimal):
    """Decimal value measured in seconds, rendered with a human-friendly unit."""

    @property
    def as_string(self):
        """Plain decimal representation, without the unit suffix."""
        return Decimal.__str__(self)

    def __float__(self):
        return float(Decimal.__str__(self))

    def __str__(self):
        return "{0}s".format(format_time(float(Decimal.__str__(self))))
71
72
73
class NameWrapper(object):
    """Wraps an object and renders its dotted 'module.name' when available."""

    def __init__(self, target):
        self.target = target

    def __str__(self):
        target = self.target
        prefix = target.__module__ + "." if hasattr(target, '__module__') else ""
        if hasattr(target, '__name__'):
            return prefix + target.__name__
        return prefix + repr(target)

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
84
85
86
def get_tag(project_name=None):
    """Build a default save-tag: '<commit-id>_<timestamp>[_uncommited-changes]'."""
    info = get_commit_info(project_name)
    pieces = [info['id'], get_current_time()]
    if info['dirty']:
        # NB: historical spelling kept so generated file names stay stable.
        pieces.append("uncommited-changes")
    return "_".join(pieces)
92
93
94
def get_machine_id():
    """Identify the host/interpreter combo, e.g. 'Linux-CPython-3.11-64bit'."""
    short_version = ".".join(platform.python_version_tuple()[:2])
    return "%s-%s-%s-%s" % (
        platform.system(),
        platform.python_implementation(),
        short_version,
        platform.architecture()[0],
    )
101
102
103
class Fallback(object):
    """Callable that tries registered implementations in registration order.

    The exception types given to the constructor mean "try the next one";
    if every implementation fails, NotImplementedError is raised.
    """

    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        raise NotImplementedError("Must have an fallback implementation for %s" % self.functions[0])

    def register(self, other):
        """Append *other* as the next fallback; returns self (decorator-friendly)."""
        self.functions.append(other)
        return self
120
121
122
# Resolve the project name: try git, then mercurial, finally the cwd name.
# Note: @register returns the Fallback itself, so all three names below end
# up bound to the same Fallback instance.
get_project_name = Fallback(IndexError, CalledProcessError, OSError)


@get_project_name.register
def get_project_name_git():
    """Derive the project name from the git 'origin' remote URL."""
    remote_url = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if isinstance(remote_url, bytes) and str != bytes:
        remote_url = remote_url.decode()
    # Last non-empty component once separators and the '.git' suffix are gone.
    name = [piece for piece in re.split(r'[/:\s]|\.git', remote_url) if piece][-1]
    return name.strip()


@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the mercurial default path."""
    remote_url = check_output(['hg', 'path', 'default'])
    remote_url = remote_url.decode()
    return remote_url.split("/")[-1].strip()


@get_project_name.register
def get_project_name_default():
    """Fall back to the name of the current working directory."""
    return basename(os.getcwd())
145
146
147
def in_any_parent(name, path=None):
    """Return True if *name* exists in *path* (default: cwd) or any ancestor."""
    if not path:
        path = os.getcwd()
    previous = None
    # Walk upwards until the entry is found or dirname() stops changing (root).
    while path and previous != path and not exists(join(path, name)):
        previous = path
        path = dirname(path)
    return exists(join(path, name))
155
156
157
def get_branch_info():
    """Return the current VCS branch name, or a parenthesised diagnostic."""
    def run(command):
        return check_output(command.split(), stderr=subprocess.STDOUT, universal_newlines=True)

    try:
        if in_any_parent('.git'):
            branch = run('git rev-parse --abbrev-ref HEAD').strip()
            # git prints literally 'HEAD' when not on any branch.
            return '(detached head)' if branch == 'HEAD' else branch
        if in_any_parent('.hg'):
            return run('hg branch').strip()
        return '(unknown vcs)'
    except subprocess.CalledProcessError as exc:
        return '(error: %s)' % exc.output.strip()
174
175
176
def get_commit_info(project_name=None):
    """Collect VCS metadata for the current checkout.

    Returns a dict with keys: id, time, author_time, dirty, project, branch.
    On any failure the dict instead carries id='unknown' plus an 'error' key
    with the repr of the exception, so callers never crash on VCS problems.
    """
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    try:
        if in_any_parent('.git'):
            # Full 40-char hash; a trailing '-dirty' marks uncommitted changes.
            desc = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                universal_newlines=True).strip()
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            # Strip the 'g' prefix git puts before the abbreviated hash.
            commit = desc[-1].strip('g')
            # %cI / %aI are the strict ISO-8601 committer and author dates.
            commit_time = check_output('git show -s --pretty=format:"%cI"'.split(),
                                       universal_newlines=True).strip().strip('"')
            author_time = check_output('git show -s --pretty=format:"%aI"'.split(),
                                       universal_newlines=True).strip().strip('"')
        elif in_any_parent('.hg'):
            desc = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            # Mercurial appends '+' to the id when there are uncommitted changes.
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            commit_time = check_output('hg tip --template "{date|rfc3339date}"'.split(),
                                       universal_newlines=True).strip().strip('"')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        # Best-effort: report the failure in the result instead of raising.
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }
221
222
223
def get_current_time():
    """Current UTC time formatted as 'YYYYMMDD_HHMMSS'."""
    return "{0:%Y%m%d_%H%M%S}".format(datetime.utcnow())
225
226
227
def first_or_value(obj, value):
    """Return the single element of *obj*, or *value* when *obj* is falsy.

    Deliberately uses tuple unpacking so that a container with more than one
    element raises ValueError instead of silently picking the first.
    """
    if not obj:
        return value
    (only,) = obj
    return only
232
233
234
def short_filename(path, machine_id=None):
    """Render a compact display name for a report *path*.

    Drops a leading *machine_id* component and the last component's extension.
    Anything without a ``.parts`` attribute (e.g. a plain string) is returned
    unchanged via str().
    """
    try:
        last = len(path.parts) - 1
    except AttributeError:
        # Not a pathlib-style object.
        return str(path)
    kept = []
    for pos, part in enumerate(path.parts):
        if pos == 0 and part == machine_id:
            continue
        if pos == last:
            part = part.rsplit('.', 1)[0]
            # if len(part) > 16:
            #     part = "%.13s..." % part
        kept.append(part)
    return '/'.join(kept)
249
250
251
def load_timer(string):
    """Resolve a dotted 'module.attr' string into a NameWrapper'd timer."""
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod != 'pep418':
        __import__(mod)
        return NameWrapper(getattr(sys.modules[mod], attr))
    if PY3:
        # On Python 3 the PEP-418 clocks live in the stdlib time module.
        import time
        return NameWrapper(getattr(time, attr))
    from . import pep418
    return NameWrapper(getattr(pep418, attr))
266
267
268
class RegressionCheck(object):
    """Base class for --benchmark-compare-fail checks on one stats field.

    Subclasses implement compute(current, compared) -> float.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when compute() exceeds the threshold, else None."""
        value = self.compute(current, compared)
        if value > self.threshold:
            return "Field %r has failed %s: %.9f > %.9f" % (
                self.field, self.__class__.__name__, value, self.threshold
            )
        return None
279
280
281
class PercentageRegressionCheck(RegressionCheck):
    """Fails when the field grew by more than the threshold, in percent."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if not baseline:
            # Any change from a zero baseline is an infinite increase.
            return float("inf")
        return current[self.field] / baseline * 100 - 100
287
288
289
class DifferenceRegressionCheck(RegressionCheck):
    """Fails when the field grew by more than the threshold, in absolute units."""

    def compute(self, current, compared):
        delta = current[self.field] - compared[self.field]
        return delta
292
293
294
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
                                      r'0-9]+)?))$')):
    """Parse a --benchmark-compare-fail spec, e.g. 'mean:5%' or 'max:0.001'.

    Returns a PercentageRegressionCheck for 'field:NN%' or a
    DifferenceRegressionCheck for 'field:number'; raises
    argparse.ArgumentTypeError on anything else.

    Fix: the regex literals are now raw strings — previously '\\.' relied on
    an invalid escape sequence (SyntaxWarning on modern Python).
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
307
308
309
def parse_warmup(string):
    """Parse the --benchmark-warmup option into a boolean."""
    value = string.lower().strip()
    if value == "auto":
        # Warmup matters most on JITed interpreters, so auto-enable on PyPy.
        return platform.python_implementation() == "PyPy"
    if value in ("off", "false", "no"):
        return False
    if value in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % value)
319
320
321
def name_formatter_short(bench):
    """Short display name: 4-char source abbreviation, 'test_' prefix stripped."""
    name = bench["name"]
    source = bench["source"]
    if source:
        name = "%s (%.4s)" % (name, split(source)[-1])
    if name.startswith("test_"):
        name = name[len("test_"):]
    return name
328
329
330
def name_formatter_normal(bench):
    """Normal display name: source path appended, file part capped at 12 chars."""
    name = bench["name"]
    if bench["source"]:
        pieces = bench["source"].split('/')
        pieces[-1] = pieces[-1][:12]
        name = "%s (%s)" % (name, '/'.join(pieces))
    return name
337
338
339
def name_formatter_long(bench):
    """Long display name: the full test name, plus the source when available."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
344
345
346
# Maps each --benchmark-name option value to its formatter function.
NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
}
351
352
353
def parse_name_format(string):
    """Validate the --benchmark-name option against the known formatters."""
    string = string.lower().strip()
    if string in NAME_FORMATTERS:
        return string
    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
359
360
361
def parse_timer(string):
    """Resolve a --benchmark-timer value and normalize it to its dotted name."""
    timer = load_timer(string)
    return str(timer)
363
364
365
def parse_sort(string):
    """Validate the --benchmark-sort column name."""
    string = string.lower().strip()
    if string in ("min", "max", "mean", "stddev", "name", "fullname"):
        return string
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % string)
373
374
375
def parse_columns(string):
    """Parse the --benchmark-columns comma-separated list, validating each name."""
    columns = [piece.strip() for piece in string.lower().split(',')]
    unknown = set(columns) - set(ALLOWED_COLUMNS)
    if unknown:
        # Report both the bad names and the full set of valid ones.
        msg = "Invalid column name(s): %s. " % ', '.join(unknown)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
384
385
386
def parse_rounds(string):
    """Parse --benchmark-rounds as a positive integer."""
    try:
        value = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if value < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return value
395
396
397
def parse_seconds(string):
    """Parse a decimal seconds value, normalizing it through SecondsDecimal."""
    try:
        value = SecondsDecimal(string)
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
    return value.as_string
402
403
404
def parse_save(string):
    """Validate a --benchmark-save name: non-empty, no path/shell characters."""
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    bad = [c for c in r"\/:*?<>|" if c in string]
    if bad:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(bad))
    return string
411
412
413
def _parse_hosts(storage_url, netrc_file):
414
415
    # load creds from netrc file
416
    path = os.path.expanduser(netrc_file)
417
    creds = None
418
    if netrc_file and os.path.isfile(path):
419
        creds = netrc.netrc(path)
420
421
    # add creds to urls
422
    urls = []
423
    for netloc in storage_url.netloc.split(','):
424
        auth = ""
425
        if creds and '@' not in netloc:
426
            host = netloc.split(':').pop(0)
427
            res = creds.authenticators(host)
428
            if res:
429
                user, _, secret = res
430
                auth = "{user}:{secret}@".format(user=user, secret=secret)
431
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
432
                                                 netloc=netloc, auth=auth)
433
        urls.append(url)
434
    return urls
435
436
437
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an elasticsearch storage URL into (hosts, index, doctype, project_name).

    The URL path may override the index ('/myindex') and doctype
    ('/myindex/mydoctype'); a 'project_name' query parameter overrides the
    auto-detected project name.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index, doctype = default_index, default_doctype
    if storage_url.path and storage_url.path != "/":
        segments = storage_url.path.strip("/").split("/")
        index = segments[0]
        if len(segments) >= 2:
            doctype = segments[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
454
455
456
def load_storage(storage, **kwargs):
    """Instantiate the storage backend for a --benchmark-storage URI."""
    if "://" not in storage:
        # Bare paths are treated as file storage.
        storage = "file://" + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage
    file_prefix = "file://"
    es_prefix = "elasticsearch+"
    if storage.startswith(file_prefix):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_prefix):], **kwargs)
    if storage.startswith(es_prefix):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len(es_prefix):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
472
473
474
def time_unit(value):
    """Pick a display unit for *value* seconds: returns (prefix, multiplier)."""
    thresholds = (
        (1e-6, ("n", 1e9)),
        (1e-3, ("u", 1e6)),
        (1.0, ("m", 1e3)),
    )
    for limit, result in thresholds:
        if value < limit:
            return result
    return "", 1.
483
484
485
def operations_unit(value):
    """Pick a display unit for an operations count: returns (prefix, multiplier)."""
    for limit, prefix, scale in ((1e+6, "M", 1e-6), (1e+3, "K", 1e-3)):
        if value > limit:
            return prefix, scale
    return "", 1.
491
492
493
def format_time(value):
    """Format *value* seconds with two decimals and the appropriate unit prefix."""
    unit, multiplier = time_unit(value)
    return "{0:.2f}{1:s}".format(value * multiplier, unit)
496
497
498
class cached_property(object):
    """Non-data descriptor caching the computed value on the instance.

    After the first access the value lives in the instance __dict__ under the
    function's name, so subsequent lookups bypass the descriptor entirely.
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself: return the descriptor.
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
508
509
510
def funcname(f):
    """Best-effort readable name for *f*, unwrapping functools.partial."""
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        # No __name__ anywhere: fall back to the plain string form.
        return str(f)
518
519
520
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: CodeType grew co_posonlyargcount (and later versions
            # added more fields), so the positional constructor call below
            # raises TypeError there — which the except clause does NOT catch.
            # replace() with no arguments builds an identical copy portably.
            co2 = co.replace()
        elif PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        return types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        # Builtins and other objects without __code__ can't be cloned.
        return f
552
553
554
def format_dict(obj):
    """Render a dict as '{k: v, ...}' with keys sorted and values JSON-encoded."""
    pairs = sorted(obj.items())
    body = ", ".join("%s: %s" % (key, json.dumps(value)) for key, value in pairs)
    return "{%s}" % body
556
557
558
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that never raises: unserializable objects become repr() markers."""

    def default(self, o):
        return "UNSERIALIZABLE[%r]" % (o,)
561
562
563
def safe_dumps(obj, **kwargs):
    """Serialize *obj* to JSON, representing unserializable values via repr()."""
    return json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
565
566
567
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Yield (progress_line, item) pairs, rewriting the terminal line as it goes.

    *format_string* receives pos (1-based), total, value and any extra kwargs.
    """
    total = len(iterable)

    def generate():
        for index, item in enumerate(iterable, 1):
            line = format_string.format(pos=index, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(line, black=True, bold=True)
            yield line, item

    return generate()
577
578
579
def report_noprogress(iterable, *args, **kwargs):
    """Drop-in replacement for report_progress that emits no terminal output."""
    for item in iterable:
        yield "", item
582
583
584
def slugify(name):
    """Replace filesystem/shell-special characters (and spaces) with underscores.

    Double underscores produced along the way are collapsed as each character
    is processed.

    Fix: the character set is now a raw string — previously '\\/' relied on an
    invalid escape sequence (SyntaxWarning on modern Python); the characters
    are identical.
    """
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
588
589
590
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path."""
    # Backport of ntpath.commonpath: Windows semantics — case-insensitive
    # comparison, both '\\' and '/' accepted as separators.

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Work in bytes or str consistently with the first element.
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        # The set collapses to a single element only when every path agrees on
        # being absolute or relative; otherwise the unpack raises ValueError.
        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        # Comparing only the lexicographic min and max is sufficient: any
        # prefix common to those two is common to every path in between.
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        # Produce the stdlib's standard mixed-types error message, then re-raise.
        genericpath._check_arg_types('commonpath', *paths)
        raise
639
640
641
def get_cprofile_functions(stats):
    """Convert a pstats structure into a list of per-function profile dicts."""
    functions = []
    # Paths are shortened relative to the parent of the cwd; this assumes
    # that you run py.test from the project root dir.
    root_parent = dirname(os.getcwd())

    for func_key, timing in stats.stats.items():
        filename = func_key[0]
        if filename.startswith(root_parent):
            filename = filename[len(root_parent):].lstrip('/')
        location = '{0}:{1}({2})'.format(filename, func_key[1], func_key[2])

        # For recursive functions show 'total calls/primitive calls'.
        # NOTE(review): timing appears to be the pstats (cc, nc, tt, ct, ...)
        # tuple — confirm the exact index semantics against pstats docs.
        if timing[0] == timing[1]:
            call_display = str(timing[0])
        else:
            call_display = '{1}/{0}'.format(timing[0], timing[1])

        functions.append(dict(
            ncalls_recursion=call_display,
            ncalls=timing[1],
            tottime=timing[2],
            tottime_per=timing[2] / timing[0] if timing[0] > 0 else 0,
            cumtime=timing[3],
            cumtime_per=timing[3] / timing[0] if timing[0] > 0 else 0,
            function_name=location,
        ))

    return functions
670