Completed
Pull Request — master (#114)
by
unknown
31s
created

get_commit_info()   C

Complexity

Conditions 8

Size

Total Lines 45

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 8
dl 0
loc 45
rs 6.9333
c 5
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
# This is here (in the utils module) because it might be used by
27
# various other modules.
28
try:
29
    from pathlib2 import Path   # noqa: F401
30
except ImportError:
31
    from pathlib import Path    # noqa: F401
32
33
try:
34
    from urllib.parse import urlparse, parse_qs
35
except ImportError:
36
    from urlparse import urlparse, parse_qs
37
38
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    # The running interpreter's subprocess module does not export these
    # names, so equivalent stand-ins are defined here.
    class CalledProcessError(subprocess.CalledProcessError):
        """CalledProcessError variant that also carries the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command, return its stdout, raise CalledProcessError on a non-zero exit."""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = process.communicate()[0]
        status = process.poll()
        if not status:
            return captured
        # Report the command the same way Popen received it.
        failed_cmd = kwargs.get("args")
        if failed_cmd is None:
            failed_cmd = popenargs[0]
        raise CalledProcessError(status, failed_cmd, captured)
58
59
# Human-readable labels keyed by the unit prefix that time_unit() returns.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)",
}

# Column names that parse_columns() accepts.
ALLOWED_COLUMNS = [
    "min",
    "max",
    "mean",
    "stddev",
    "median",
    "iqr",
    "ops",
    "outliers",
    "rounds",
    "iterations",
]
66
67
68
class SecondsDecimal(Decimal):
    """Decimal holding a number of seconds.

    ``str()`` renders the value through ``format_time`` with an ``s`` suffix;
    the plain decimal text stays available through ``as_string``.
    """

    @property
    def as_string(self):
        """The unformatted text that ``Decimal.__str__`` produces."""
        return Decimal.__str__(self)

    def __float__(self):
        return float(Decimal.__str__(self))

    def __str__(self):
        seconds = float(Decimal.__str__(self))
        return "{0}s".format(format_time(seconds))
78
79
80
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its dotted ``module.name``.

    The module prefix is left out when the target has no ``__module__`` or
    when that attribute is ``None``. When the target has no ``__name__`` its
    ``repr()`` is used instead.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        # __module__ can exist and still be None; concatenating that with "."
        # used to raise TypeError.
        module = getattr(self.target, '__module__', None)
        name = "" if module is None else module + "."
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
91
92
93
def get_tag(project_name=None):
    """Build a tag from the commit id and the current time.

    ``uncommited-changes`` is appended when the commit info is flagged dirty.
    The parts are joined with underscores.
    """
    commit_info = get_commit_info(project_name)
    tag_parts = [commit_info['id'], get_current_time()]
    tag_parts.extend(["uncommited-changes"] if commit_info['dirty'] else [])
    return "_".join(tag_parts)
99
100
101
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture>`` for this interpreter."""
    system = platform.system()
    implementation = platform.python_implementation()
    major_minor = ".".join(platform.python_version_tuple()[:2])
    bits = platform.architecture()[0]
    return "%s-%s-%s-%s" % (system, implementation, major_minor, bits)
108
109
110
class Fallback(object):
    """Callable that tries registered functions before a default one.

    Registered functions are called in registration order with the caller's
    arguments. The first truthy result is returned. A function that raises
    one of ``exceptions`` or returns a falsy value is skipped. When none of
    them produces a result, ``fallback`` is called and its result returned.
    """

    def __init__(self, fallback, exceptions):
        self.fallback = fallback
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                value = func(*args, **kwargs)
            except self.exceptions:
                continue
            # Previously the value was computed and then discarded, so the
            # fallback always ran; return the first usable result instead.
            if value:
                return value
        return self.fallback(*args, **kwargs)

    def register(self, other):
        """Add ``other`` to the candidates; returns this ``Fallback`` (decorator use)."""
        self.functions.append(other)
        return self
131
132
133
@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    """Default project name: the base name of the current working directory."""
    working_dir = os.getcwd()
    return basename(working_dir)
136
137
138
@get_project_name.register
def get_project_name_git():
    """Derive the project name from the git ``remote.origin.url`` setting.

    Returns the last non-empty component of the URL with a trailing ``.git``
    removed, or ``None`` when ``git rev-parse --git-dir`` prints nothing.
    An empty URL raises IndexError.
    """
    is_git = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if is_git:
        project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
        if isinstance(project_address, bytes) and str != bytes:
            project_address = project_address.decode()
        components = []
        for component in re.split(r'[/:\s\\]', project_address):
            # Only drop ".git" as a suffix. Splitting on it anywhere turned
            # "user.github.io.git" into "hub.io".
            if component.endswith('.git'):
                component = component[:-len('.git')]
            if component:
                components.append(component)
        return components[-1].strip()
147
148
149
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the last ``/`` segment of ``hg path default``."""
    raw_address = check_output(['hg', 'path', 'default', '--quiet'])
    address = raw_address.decode()
    last_segment = address.split("/")[-1]
    return last_segment.strip()
155
156
157
def in_any_parent(name, path=None):
    """Tell whether ``name`` exists in ``path`` or in one of its parent directories.

    ``path`` defaults to the current working directory. The walk stops when
    ``dirname`` no longer changes the path or the path becomes empty.
    """
    current = path or os.getcwd()
    previous = None
    while current and current != previous:
        if exists(join(current, name)):
            break
        previous, current = current, dirname(current)
    return exists(join(current, name))
165
166
167
def subprocess_output(cmd):
    """Run whitespace-split ``cmd`` and return its stripped text output (stderr merged in)."""
    argv = cmd.split()
    raw = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return raw.strip()
169
170
171
def get_commit_info(project_name=None):
    """Collect commit metadata from a surrounding git or hg checkout.

    Returns a dict with ``id``, ``time``, ``author_time``, ``dirty``,
    ``project`` and ``branch``. If anything raises while querying the VCS,
    ``id`` is ``'unknown'``, both times are ``None`` and an ``error`` entry
    describes the exception. ``dirty`` and ``branch`` keep whatever had been
    determined before the failure.
    """
    project_name = project_name or get_project_name()
    is_dirty = False
    revision = 'unversioned'
    committed_at = None
    authored_at = None
    branch_name = '(unknown)'
    try:
        if in_any_parent('.git'):
            described = subprocess_output('git describe --dirty --always --long --abbrev=40').split('-')
            if described[-1].strip() == 'dirty':
                is_dirty = True
                described.pop()
            revision = described[-1].strip('g')
            # The command is split on whitespace, so the quotes reach git
            # verbatim and come back in the output; strip them here.
            committed_at = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            authored_at = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch_name = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch_name == 'HEAD':
                branch_name = '(detached head)'
        elif in_any_parent('.hg'):
            described = subprocess_output('hg id --id --debug')
            if described[-1] == '+':
                is_dirty = True
            revision = described.strip('+')
            committed_at = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch_name = subprocess_output('hg branch')
        return {
            'id': revision,
            'time': committed_at,
            'author_time': authored_at,
            'dirty': is_dirty,
            'project': project_name,
            'branch': branch_name,
        }
    except Exception as exc:
        if isinstance(exc, CalledProcessError):
            error = 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
        else:
            error = repr(exc)
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': is_dirty,
            'error': error,
            'project': project_name,
            'branch': branch_name,
        }
217
218
219
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
221
222
223
def first_or_value(obj, value):
    """Return the single item of ``obj``, or ``value`` when ``obj`` is falsy.

    A truthy ``obj`` must contain exactly one item; otherwise the unpacking
    raises ValueError.
    """
    if not obj:
        return value
    only_item, = obj
    return only_item
228
229
230
def short_filename(path, machine_id=None):
    """Return a short ``/``-joined form of ``path``.

    A leading component equal to ``machine_id`` is dropped, and the final
    component loses its last ``.ext`` suffix. Objects without a ``parts``
    attribute are returned through ``str()`` unchanged.
    """
    try:
        final_index = len(path.parts) - 1
    except AttributeError:
        return str(path)
    kept = []
    for index, component in enumerate(path.parts):
        if index == 0 and component == machine_id:
            continue
        if index == final_index:
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
245
246
247
def load_timer(string):
    """Resolve a dotted ``module.attr`` timer name into a ``NameWrapper``.

    The pseudo-module ``pep418`` maps to the ``time`` module when ``PY3`` is
    set and to the package's own ``pep418`` module otherwise. Raises
    ``argparse.ArgumentTypeError`` when ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr_name = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        module = sys.modules[module_name]
        return NameWrapper(getattr(module, attr_name))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr_name))
    from . import pep418
    return NameWrapper(getattr(pep418, attr_name))
262
263
264
class RegressionCheck(object):
    """Base for threshold checks on one benchmark field.

    Subclasses supply ``compute(current, compared)``. ``fails`` compares its
    result against ``threshold``.
    """

    def __init__(self, field, threshold):
        self.threshold = threshold
        self.field = field

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else ``None``."""
        val = self.compute(current, compared)
        if not val > self.threshold:
            return None
        details = (self.field, self.__class__.__name__, val, self.threshold)
        return "Field %r has failed %s: %.9f > %.9f" % details
275
276
277
class PercentageRegressionCheck(RegressionCheck):
    """Regression check on the percentage change of a field."""

    def compute(self, current, compared):
        """Return ``current / compared * 100 - 100`` for the field, or ``inf`` when the baseline is falsy."""
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        return float("inf")
283
284
285
class DifferenceRegressionCheck(RegressionCheck):
    """Regression check on the absolute difference of a field."""

    def compute(self, current, compared):
        """Return the field's value in ``current`` minus its value in ``compared``."""
        field = self.field
        return current[field] - compared[field]
288
289
290
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|'
                                      r'(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``field:threshold`` expression into a regression check.

    ``field:NN%`` (one or two digits) produces a ``PercentageRegressionCheck``;
    ``field:<number>`` produces a ``DifferenceRegressionCheck``. Anything else
    raises ``argparse.ArgumentTypeError``.

    The pattern is written with raw strings so that ``\\.`` is not treated as
    an (invalid) string escape; the compiled pattern itself is unchanged.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
303
304
305
def parse_warmup(string):
    """Parse a warmup flag.

    ``auto`` is true only when running on PyPy. ``on``/``true``/``yes`` and
    the empty string are true; ``off``/``false``/``no`` are false. Matching
    is case-insensitive and ignores surrounding whitespace. Other values
    raise ``argparse.ArgumentTypeError``.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ("off", "false", "no"):
        return False
    if normalized in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
315
316
317
def name_formatter_short(bench):
    """Short display name.

    When a source is present, the first four characters of its final path
    component are appended in parentheses. A leading ``test_`` is removed.
    """
    label = bench["name"]
    if bench["source"]:
        source_tail = split(bench["source"])[-1]
        label = "%s (%.4s)" % (label, source_tail)
    return label[5:] if label.startswith("test_") else label
324
325
326
def name_formatter_normal(bench):
    """Display name followed by the source, whose last ``/`` segment is cut to 12 characters."""
    label = bench["name"]
    if not bench["source"]:
        return label
    head, slash, tail = bench["source"].rpartition('/')
    return "%s (%s)" % (label, head + slash + tail[:12])
333
334
335
def name_formatter_long(bench):
    """Full benchmark name, followed by the source in parentheses when one is set."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
340
341
342
# Name-format keyword -> function that renders a benchmark's display name.
NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
}
347
348
349
def parse_name_format(string):
    """Normalise ``string`` and return it if it is a key of ``NAME_FORMATTERS``.

    Raises ``argparse.ArgumentTypeError`` for unknown names.
    """
    candidate = string.lower().strip()
    if candidate not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % candidate)
    return candidate
355
356
357
def parse_timer(string):
    """Validate a timer name via ``load_timer`` and return its string form."""
    timer = load_timer(string)
    return str(timer)
359
360
361
def parse_sort(string):
    """Normalise and validate a sort key.

    Returns the lower-cased, stripped key; raises
    ``argparse.ArgumentTypeError`` when it is not one of the accepted names.
    """
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
369
370
371
def parse_columns(string):
    """Split a comma-separated column list and validate it against ``ALLOWED_COLUMNS``.

    Returns the lower-cased, stripped names in the order given. Raises
    ``argparse.ArgumentTypeError`` naming any column that is not allowed.
    """
    columns = list(map(str.strip, string.lower().split(',')))
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if not invalid:
        return columns
    msg = "Invalid column name(s): %s. " % ', '.join(invalid)
    msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
    raise argparse.ArgumentTypeError(msg)
380
381
382
def parse_rounds(string):
    """Parse a round count: an integer that must be at least 1.

    Raises ``argparse.ArgumentTypeError`` for non-integers and for values
    below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds >= 1:
        return rounds
    raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
391
392
393
def parse_seconds(string):
    """Validate ``string`` as a decimal number and return its plain decimal text.

    Any failure is re-raised as ``argparse.ArgumentTypeError``.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
398
399
400
def parse_save(string):
    """Validate a save name: non-empty and free of ``\\ / : * ? < > |``.

    Returns ``string`` unchanged; raises ``argparse.ArgumentTypeError``
    otherwise.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [c for c in r"\/:*?<>|" if c in string]
    if found:
        illegal = ''.join(found)
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string
407
408
409
def _parse_hosts(storage_url, netrc_file):
410
411
    # load creds from netrc file
412
    path = os.path.expanduser(netrc_file)
413
    creds = None
414
    if netrc_file and os.path.isfile(path):
415
        creds = netrc.netrc(path)
416
417
    # add creds to urls
418
    urls = []
419
    for netloc in storage_url.netloc.split(','):
420
        auth = ""
421
        if creds and '@' not in netloc:
422
            host = netloc.split(':').pop(0)
423
            res = creds.authenticators(host)
424
            if res:
425
                user, _, secret = res
426
                auth = "{user}:{secret}@".format(user=user, secret=secret)
427
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
428
                                                 netloc=netloc, auth=auth)
429
        urls.append(url)
430
    return urls
431
432
433
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an Elasticsearch storage URL into ``(hosts, index, doctype, project_name)``.

    The first two path segments override the default index and doctype. The
    project name comes from the ``project_name`` query parameter, or from
    ``get_project_name()`` when that parameter is absent.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index, doctype = default_index, default_doctype
    if storage_url.path not in ("", "/"):
        segments = storage_url.path.strip("/").split("/")
        index = segments[0]
        if len(segments) >= 2:
            doctype = segments[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
450
451
452
def load_storage(storage, **kwargs):
    """Instantiate the storage backend selected by the ``storage`` URI.

    A value without ``://`` is treated as a ``file://`` path. The optional
    ``netrc`` keyword is consumed here and used only for Elasticsearch
    storage; remaining keyword arguments go to the backend constructor.

    Raises ``argparse.ArgumentTypeError`` for unsupported schemes.
    """
    if "://" not in storage:
        storage = "file://" + storage
    # "netrc" is optional: callers that only use file storage need not pass
    # it. A missing or None value becomes '' (the parser's own default).
    netrc_file = kwargs.pop('netrc', None) or ''
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")
468
469
470
def time_unit(value):
    """Pick a unit prefix and multiplier for a duration in seconds.

    Returns ``(prefix, multiplier)``: ``n`` below 1e-6, ``u`` below 1e-3,
    ``m`` below 1, and the empty prefix otherwise.
    """
    for upper_bound, prefix, multiplier in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < upper_bound:
            return prefix, multiplier
    return "", 1.
479
480
481
def operations_unit(value):
    """Pick a unit prefix and multiplier for an operations count.

    Returns ``("M", 1e-6)`` above one million, ``("K", 1e-3)`` above one
    thousand, and ``("", 1.)`` otherwise.
    """
    if value > 1e+6:
        chosen = ("M", 1e-6)
    elif value > 1e+3:
        chosen = ("K", 1e-3)
    else:
        chosen = ("", 1.)
    return chosen
487
488
489
def format_time(value):
    """Format ``value`` scaled by ``time_unit`` with two decimals and the unit prefix."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
492
493
494
class cached_property(object):
    """Descriptor that computes a value once per instance.

    On first access the wrapped function's result is stored in the
    instance ``__dict__`` under the function's name. Access through the
    class returns the descriptor itself.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        if obj is not None:
            result = self.func(obj)
            obj.__dict__[self.func.__name__] = result
            return result
        return self
504
505
506
def funcname(f):
    """Return the name of ``f``.

    For a ``partial`` the wrapped function's name is used. Objects without a
    ``__name__`` yield ``str(f)``.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
514
515
516
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the function attributes needed for cloning (``__code__``,
    ``__globals__``, ...) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    try:
        # first of all, we clone the code object
        co = f.__code__
        if hasattr(co, 'replace'):
            # The positional CodeType() signature has changed between Python
            # releases; replace() copies the code object without depending
            # on it.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
    # FunctionType() takes no keyword-only defaults; carry them over so the
    # clone accepts the same calls as the original.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        clone.__kwdefaults__ = dict(kwdefaults)
    return clone
548
549
550
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items sorted."""
    rendered = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{%s}" % ", ".join(rendered)
552
553
554
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that substitutes a placeholder string for unserializable objects."""

    def default(self, o):
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
557
558
559
def safe_dumps(obj, **kwargs):
    """Serialize ``obj`` with ``json.dumps`` using ``SafeJSONEncoder``.

    Additional keyword arguments are forwarded to ``json.dumps``.
    """
    serialized = json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
    return serialized
561
562
563
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` so that each step rewrites a progress line.

    Yields ``(message, item)`` pairs. ``format_string`` is formatted with
    ``pos`` (1-based), ``total``, ``value`` and the extra keyword arguments.
    ``len(iterable)`` is taken immediately, before iteration starts.
    """
    total = len(iterable)

    def progress_reporting_wrapper():
        for position, item in enumerate(iterable, 1):
            message = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return progress_reporting_wrapper()
573
574
575
def report_noprogress(iterable, *args, **kwargs):
    """Yield ``("", item)`` for every item; extra arguments are accepted and ignored."""
    for item in iterable:
        yield "", item
578
579
580
def slugify(name):
    """Replace ``\\ / : * ? < > |`` and spaces in ``name`` with underscores.

    After each character's replacement, one pass collapses ``__`` into ``_``.
    """
    # The backslash is escaped explicitly: "\/" is an invalid escape sequence
    # that Python only tolerates with a warning.
    for c in "\\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
584
585
586
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are handled with ``ntpath`` rules: ``/`` is treated like ``\\`` and
    comparison is case-insensitive. Raises ValueError for an empty sequence,
    for a mix of absolute and relative paths, and for differing drives.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    use_bytes = isinstance(paths[0], bytes)
    sep = b'\\' if use_bytes else '\\'
    altsep = b'/' if use_bytes else '/'
    curdir = b'.' if use_bytes else '.'

    def significant(components):
        # Drop empty components and references to the current directory.
        return [c for c in components if c and c != curdir]

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [tail.split(sep) for _, tail in drivesplits]

        absolute_flags = set(tail[:1] == sep for _, tail in drivesplits)
        try:
            isabs, = absolute_flags
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(drv for drv, _ in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result is sliced from the first path in its original case.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = significant(path.split(sep))

        split_paths = [significant(s) for s in split_paths]
        shortest = min(split_paths)
        longest = max(split_paths)
        shared_length = len(shortest)
        for i, c in enumerate(shortest):
            if c != longest[i]:
                shared_length = i
                break
        common = common[:shared_length]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
635
636
637
def get_cprofile_functions(stats):
    """
    Convert a pstats structure into a list of dicts, one per profiled function.

    Entries follow the iteration order of ``stats.stats``; no sorting is done
    here. File paths under the parent of the current working directory are
    shown relative to it.
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())
    prefix_length = len(project_dir_parent)

    result = []
    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[prefix_length:].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        first_count = run_info[0]
        second_count = run_info[1]
        # if the function is recursive write number of 'total calls/primitive calls'
        if first_count == second_count:
            calls = str(first_count)
        else:
            calls = '{1}/{0}'.format(first_count, second_count)

        has_calls = first_count > 0
        entry = dict(ncalls_recursion=calls,
                     ncalls=second_count,
                     tottime=run_info[2],
                     tottime_per=run_info[2] / first_count if has_calls else 0,
                     cumtime=run_info[3],
                     cumtime_per=run_info[3] / first_count if has_calls else 0,
                     function_name=function_name)
        result.append(entry)

    return result
666