Completed
Push — master ( 3a231d...275b8a )
by Ionel Cristian
34s
created

get_commit_info()   D

Complexity

Conditions 8

Size

Total Lines 45

Duplication

Lines 0
Ratio 0 %

Importance

Changes 5
Bugs 0 Features 0
Metric Value
cc 8
dl 0
loc 45
rs 4
c 5
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    # Older Python versions whose subprocess module lacks check_output get a
    # small local implementation with the same calling convention.

    class CalledProcessError(subprocess.CalledProcessError):
        """``subprocess.CalledProcessError`` that also records the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command and return its captured stdout.

        Raises ValueError if ``stdout`` is passed (it is always a pipe here) and
        CalledProcessError (carrying the captured output) on a non-zero exit status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        kwargs['stdout'] = subprocess.PIPE
        proc = subprocess.Popen(*popenargs, **kwargs)
        captured = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return captured
        command = kwargs.get("args")
        if command is None:
            command = popenargs[0]
        raise CalledProcessError(status, command, captured)
51
52
# Human-readable labels for the unit prefixes produced by time_unit()
# ("" for seconds, "m", "u" and "n" for the sub-second scales).
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Statistic column names accepted by parse_columns().
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
59
60
61
class SecondsDecimal(Decimal):
    """A Decimal amount of seconds whose ``str()`` is a human-readable duration.

    ``str()`` renders through format_time() with an ``s`` suffix (for example
    ``SecondsDecimal("0.0015")`` prints as ``1.50ms``), while ``as_string`` gives
    the plain Decimal text.
    """

    @property
    def as_string(self):
        """The plain decimal text of the value (Decimal's own ``str()``)."""
        return super(SecondsDecimal, self).__str__()

    def __float__(self):
        return float(self.as_string)

    def __str__(self):
        return "{0}s".format(format_time(float(self.as_string)))
71
72
73
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields a dotted ``module.name`` label.

    The module prefix is used only when the target has ``__module__``. When the
    target has no ``__name__``, its ``repr()`` is used instead.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        obj = self.target
        if hasattr(obj, '__module__'):
            prefix = obj.__module__ + "."
        else:
            prefix = ""
        if hasattr(obj, '__name__'):
            suffix = obj.__name__
        else:
            suffix = repr(obj)
        return prefix + suffix

    def __repr__(self):
        wrapped = repr(self.target)
        return "NameWrapper(%s)" % wrapped
84
85
86
def get_tag(project_name=None):
    """Build a run tag ``<commit id>_<UTC timestamp>``.

    ``_uncommited-changes`` is appended when get_commit_info() reports a dirty
    working tree.
    """
    info = get_commit_info(project_name)
    head = [info['id'], get_current_time()]
    tail = ["uncommited-changes"] if info['dirty'] else []
    return "_".join(head + tail)
92
93
94
def get_machine_id():
    """Identify the running platform as ``<OS>-<implementation>-<major.minor>-<bits>``.

    Example: ``Linux-CPython-3.11-64bit``.
    """
    major_minor = ".".join(platform.python_version_tuple()[:2])
    fields = (
        platform.system(),
        platform.python_implementation(),
        major_minor,
        platform.architecture()[0],
    )
    return "-".join(fields)
101
102
103
class Fallback(object):
    """Callable that tries registered implementations in order.

    The result of the first implementation that does not raise one of
    ``exceptions`` is returned. Any other exception propagates immediately.
    If every implementation fails, or none were registered, NotImplementedError
    is raised.
    """

    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        # Guard the empty case: indexing functions[0] used to raise IndexError
        # here instead of the intended NotImplementedError.
        target = self.functions[0] if self.functions else "<no registered implementations>"
        raise NotImplementedError("Must have a fallback implementation for %s" % (target,))

    def register(self, other):
        """Append ``other`` as the next implementation to try; return ``self``.

        Returning ``self`` allows use as a decorator. Note that the decorated
        name is then bound to this Fallback object, not to the original function.
        """
        self.functions.append(other)
        return self
120
121
122
# Errors that make get_project_name move on to the next registered strategy
# (the git, Mercurial and current-directory implementations registered below).
_PROJECT_NAME_ERRORS = (IndexError, CalledProcessError, OSError)
get_project_name = Fallback(*_PROJECT_NAME_ERRORS)
123
124
125
@get_project_name.register
def get_project_name_git():
    """Project name taken from the last component of git's ``remote.origin.url``.

    The URL is split on ``/``, ``:``, whitespace, backslashes and ``.git``, and
    the last non-empty piece is returned.
    """
    raw = check_output(['git', 'config', '--local', 'remote.origin.url'])
    url = raw.decode() if isinstance(raw, bytes) and str != bytes else raw
    components = [piece for piece in re.split(r'[/:\s\\]|\.git', url) if piece]
    return components[-1].strip()
132
133
134
@get_project_name.register
def get_project_name_hg():
    """Project name taken from the last component of Mercurial's ``default`` path.

    Trailing slashes and whitespace are ignored, and both ``/`` and ``\\`` count
    as separators (as in the git variant). So ``https://host/repo/`` and
    ``C:\\repos\\repo`` both yield ``repo``; previously a trailing slash gave an
    empty name. If no component remains, the IndexError lets the
    get_project_name fallback chain try the next implementation.
    """
    project_address = check_output(['hg', 'path', 'default'])
    if isinstance(project_address, bytes):
        project_address = project_address.decode()
    parts = [part for part in re.split(r'[/\\]', project_address.strip()) if part]
    return parts[-1].strip()
140
141
142
@get_project_name.register
def get_project_name_default():
    """Last-resort project name: the base name of the current working directory."""
    cwd = os.getcwd()
    return basename(cwd)
145
146
147
def in_any_parent(name, path=None):
    """Return True if ``name`` exists in ``path`` or in any of its ancestors.

    ``path`` defaults to the current working directory. The walk stops at the
    filesystem root, where ``dirname`` no longer changes the path.
    """
    current = path or os.getcwd()
    previous = None
    while current and current != previous:
        if exists(join(current, name)):
            return True
        previous, current = current, dirname(current)
    return exists(join(current, name))
155
156
157
def subprocess_output(cmd):
    """Run ``cmd`` and return its output as stripped text.

    The command string is split on whitespace; no shell is involved. stderr is
    merged into the captured output, and a non-zero exit raises
    CalledProcessError.
    """
    argv = cmd.split()
    raw = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return raw.strip()
159
160
161
def get_commit_info(project_name=None):
    """Describe the VCS state (git or Mercurial) of the current working directory.

    Returns a dict with keys ``id``, ``time``, ``author_time``, ``dirty``,
    ``project`` and ``branch``. When neither a ``.git`` nor a ``.hg`` directory
    is found in the cwd or its parents, ``id`` is ``'unversioned'``.

    If anything raises while probing the repository, a dict is returned in
    which ``id`` is ``'unknown'`` and ``time``/``author_time`` are None. It also
    carries an extra ``error`` key describing the exception; ``dirty`` and
    ``branch`` hold whatever was determined before the failure.

    :param project_name: name to report; defaults to ``get_project_name()``.
    """
    dirty = False
    commit = 'unversioned'
    commit_time = None
    author_time = None
    project_name = project_name or get_project_name()
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            # The last "-"-separated field is the commit hash (possibly prefixed
            # with "g"), optionally followed by "dirty".
            desc = subprocess_output('git describe --dirty --always --long --abbrev=40')
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            commit = desc[-1].strip('g')
            # subprocess_output splits on whitespace without a shell, so the
            # literal quotes reach git and come back in the output; strip them.
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            # Hash of the working directory parent; a trailing "+" marks
            # uncommitted changes.
            desc = subprocess_output('hg id --id --debug')
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
            # Use the working directory parent (".") so the time matches the id
            # above; `hg tip` would report the newest changeset in the repository,
            # which differs whenever an older revision is checked out.
            commit_time = subprocess_output('hg log -r . --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        # Deliberately best-effort: report the failure instead of aborting the run.
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
                     if isinstance(exc, CalledProcessError) else repr(exc),
            'project': project_name,
            'branch': branch,
        }
207
208
209
def get_current_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``.

    Uses time.gmtime() instead of datetime.utcnow(), which is deprecated and
    emits a DeprecationWarning on Python 3.12+. gmtime() works on every
    supported Python version.
    """
    import time
    return time.strftime("%Y%m%d_%H%M%S", time.gmtime())
211
212
213
def first_or_value(obj, value):
    """Return the single element of ``obj`` if it is truthy, else ``value``.

    A truthy ``obj`` must contain exactly one element; otherwise unpacking
    raises ValueError.
    """
    if not obj:
        return value
    (only,) = obj
    return only
218
219
220
def short_filename(path, machine_id=None):
    """Render a path-like object as a short ``/``-joined label.

    A leading component equal to ``machine_id`` is dropped, and the extension
    (text after the last ``.``) is removed from the final component. Objects
    without ``.parts`` are simply converted with ``str()``.
    """
    try:
        components = path.parts
        last = len(components) - 1
    except AttributeError:
        return str(path)
    kept = []
    for index, component in enumerate(components):
        if index == 0 and component == machine_id:
            continue
        if index == last:
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
235
236
237
def load_timer(string):
    """Resolve a dotted ``module.attr`` string to a NameWrapper around that attribute.

    ``pep418.<attr>`` is special-cased: on Python 3 it resolves from the
    ``time`` module; otherwise it uses the package's own ``pep418`` module.

    :raises argparse.ArgumentTypeError: if ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        return NameWrapper(getattr(sys.modules[module_name], attr))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr))
    from . import pep418
    return NameWrapper(getattr(pep418, attr))
252
253
254
class RegressionCheck(object):
    """Base class for comparing one benchmark field against a threshold.

    Subclasses provide ``compute(current, compared)``. ``fails`` returns a
    failure message when the computed value exceeds the threshold, and None
    otherwise.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        delta = self.compute(current, compared)
        if not delta > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, type(self).__name__, delta, self.threshold
        )
265
266
267
class PercentageRegressionCheck(RegressionCheck):
    """Regression measured as the percent increase of ``current`` over ``compared``."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A zero/empty baseline is treated as an infinite increase.
        return float("inf")
273
274
275
class DifferenceRegressionCheck(RegressionCheck):
    """Regression measured as the absolute difference ``current - compared``."""

    def compute(self, current, compared):
        now = current[self.field]
        before = compared[self.field]
        return now - before
278
279
280
# The pattern uses raw strings: "\." inside a plain literal is an invalid escape
# sequence (DeprecationWarning, and SyntaxWarning since Python 3.12). The
# compiled pattern is unchanged.
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%'
                                      r'|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``FIELD:THRESHOLD`` expression (e.g. ``min:5%`` or ``mean:0.001``).

    ``FIELD`` is one of min/max/mean/median/stddev/iqr. ``N%`` (0-99) yields a
    PercentageRegressionCheck with an int threshold. A plain non-negative number
    (optionally with an exponent) yields a DifferenceRegressionCheck with a float
    threshold.

    :raises argparse.ArgumentTypeError: if ``string`` does not match.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
293
294
295
def parse_warmup(string):
    """Parse a warmup flag (case-insensitive, surrounding whitespace ignored).

    ``auto`` means "only on PyPy". ``off``/``false``/``no`` give False, and
    ``on``/``true``/``yes``/empty give True. Anything else raises
    argparse.ArgumentTypeError.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ("off", "false", "no"):
        return False
    if normalized in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
305
306
307
def name_formatter_short(bench):
    """Short display name for a benchmark.

    The name gets the first 4 characters of the source's basename in
    parentheses (when there is a source), and a leading ``test_`` is removed.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
314
315
316
def name_formatter_normal(bench):
    """Display name with the ``/``-separated source appended in parentheses.

    The source's last component is truncated to 12 characters.
    """
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    head, sep, tail = source.rpartition('/')
    return "%s (%s)" % (label, head + sep + tail[:12])
323
324
325
def name_formatter_long(bench):
    """Full display name: ``fullname``, followed by ``(source)`` when a source is set."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
330
331
332
# Maps each accepted name-format value (as validated by parse_name_format) to
# the function that renders a benchmark's display name in that style.
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
)
337
338
339
def parse_name_format(string):
    """Validate a name-format choice (case-insensitive) against NAME_FORMATTERS.

    Returns the normalized key; raises argparse.ArgumentTypeError otherwise.
    """
    normalized = string.lower().strip()
    if normalized not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
    return normalized
345
346
347
def parse_timer(string):
    """Resolve a dotted timer path via load_timer() and return its ``module.name`` label."""
    timer = load_timer(string)
    return str(timer)
349
350
351
def parse_sort(string):
    """Validate a sort key (case-insensitive, whitespace-stripped) and return it normalized.

    :raises argparse.ArgumentTypeError: for keys outside the allowed set.
    """
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
359
360
361
def parse_columns(string):
    """Parse a comma-separated, case-insensitive list of column names.

    Returns the lower-cased, whitespace-stripped names in the given order.

    :raises argparse.ArgumentTypeError: if any name is not in ALLOWED_COLUMNS.
        The offending names are listed in sorted order.
    """
    # s.strip() rather than str.strip(s): the latter rejects py2 unicode input.
    columns = [column.strip() for column in string.lower().split(',')]
    # sorted(): iteration order over a set of strings varies between runs
    # (hash randomization), which made the error message nondeterministic.
    invalid = sorted(set(columns) - set(ALLOWED_COLUMNS))
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
370
371
372
def parse_rounds(string):
    """Parse a positive integer round count.

    :raises argparse.ArgumentTypeError: if ``string`` is not an int, or is below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
381
382
383
def parse_seconds(string):
    """Validate a decimal number of seconds and return its plain decimal text.

    :raises argparse.ArgumentTypeError: if SecondsDecimal cannot parse ``string``.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
388
389
390
def parse_save(string):
    """Validate a save name: non-empty and free of path/filename-unsafe characters.

    :raises argparse.ArgumentTypeError: if ``string`` is empty or contains any of
        ``\\ / : * ? < > |``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [char for char in r"\/:*?<>|" if char in string]
    if found:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(found))
    return string
397
398
399
def _parse_hosts(storage_url, netrc_file):
400
401
    # load creds from netrc file
402
    path = os.path.expanduser(netrc_file)
403
    creds = None
404
    if netrc_file and os.path.isfile(path):
405
        creds = netrc.netrc(path)
406
407
    # add creds to urls
408
    urls = []
409
    for netloc in storage_url.netloc.split(','):
410
        auth = ""
411
        if creds and '@' not in netloc:
412
            host = netloc.split(':').pop(0)
413
            res = creds.authenticators(host)
414
            if res:
415
                user, _, secret = res
416
                auth = "{user}:{secret}@".format(user=user, secret=secret)
417
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
418
                                                 netloc=netloc, auth=auth)
419
        urls.append(url)
420
    return urls
421
422
423
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an ``http[s]://host1,host2/index/doctype?project_name=X`` URL.

    Returns ``(hosts, index, doctype, project_name)``. Missing path segments fall
    back to ``default_index`` / ``default_doctype``. A missing ``project_name``
    query parameter falls back to get_project_name().
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index, doctype = default_index, default_doctype
    path = storage_url.path
    if path and path != "/":
        segments = path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]
    project_names = parse_qs(storage_url.query).get("project_name")
    project_name = project_names[0] if project_names else get_project_name()
    return hosts, index, doctype, project_name
440
441
442
def load_storage(storage, **kwargs):
    """Create a storage backend from a storage URL.

    ``storage`` is ``file://path`` (a bare path is treated as ``file://path``)
    or ``elasticsearch+http[s]://host1,host2/index/doctype``. Remaining keyword
    arguments are forwarded to the storage class.

    :param netrc: keyword-only, optional; path to a netrc file used for
        elasticsearch credentials. Defaults to ``''`` (no netrc). It is always
        removed from the keyword arguments before they are forwarded.
    :raises argparse.ArgumentTypeError: for any other URL scheme.
    """
    if "://" not in storage:
        storage = "file://" + storage
    # Default matches parse_elasticsearch_storage(netrc_file=''); previously a
    # missing 'netrc' keyword raised KeyError.
    netrc_file = kwargs.pop('netrc', '')  # only used by elasticsearch storage
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")
458
459
460
def time_unit(value):
    """Pick a display unit for a duration given in seconds.

    Returns ``(prefix, multiplier)`` such that ``value * multiplier`` is
    expressed in that unit: ``n`` below 1e-6, ``u`` below 1e-3, ``m`` below 1,
    and ``""`` (seconds) otherwise.
    """
    for limit, prefix, multiplier in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < limit:
            return prefix, multiplier
    return "", 1.
469
470
471
def operations_unit(value):
    """Pick a display unit for an operations-per-second rate.

    Returns ``(prefix, multiplier)``: ``M`` above 1e6, ``K`` above 1e3, and
    ``""`` otherwise.
    """
    if value > 1e+6:
        prefix, multiplier = "M", 1e-6
    elif value > 1e+3:
        prefix, multiplier = "K", 1e-3
    else:
        prefix, multiplier = "", 1.
    return prefix, multiplier
477
478
479
def format_time(value):
    """Format a duration in seconds with 2 decimals and a unit prefix from time_unit().

    For example, 0.0015 renders as ``1.50m``.
    """
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
482
483
484
class cached_property(object):
    """Descriptor that computes a value once per instance and caches it.

    On first access the result is stored in the instance ``__dict__`` under the
    function's name. Because this class defines no ``__set__``, that instance
    attribute then shadows the descriptor. Class-level access returns the
    descriptor itself.
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
494
495
496
def funcname(f):
    """Best-effort name of a callable.

    For a functools.partial, the wrapped function's name is used. If there is
    no ``__name__``, ``str(f)`` is returned.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
504
505
506
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the attributes of a plain Python function (``__code__``,
    ``__globals__``, ...), such as builtins or functools.partial, are returned
    unchanged. Keyword-only defaults (``__kwdefaults__``) are carried over to
    the clone.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """
    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: the CodeType constructor gained extra parameters
            # (co_posonlyargcount in 3.8, more in 3.11), so positional
            # construction raises TypeError. code.replace() builds a new code
            # object portably.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            # Python 3.0 - 3.7 constructor signature.
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            # Python 2 constructor signature (no keyword-only arguments).
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
    # FunctionType() does not take keyword-only defaults; without this the clone
    # would demand every keyword-only argument explicitly.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        clone.__kwdefaults__ = dict(kwdefaults)
    return clone
538
539
540
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items sorted by key.

    Keys are inserted with ``%s`` (unquoted); values are JSON-encoded.
    """
    pairs = []
    for key, val in sorted(obj.items()):
        pairs.append("%s: %s" % (key, json.dumps(val)))
    return "{%s}" % ", ".join(pairs)
542
543
544
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that never fails on unknown types.

    Objects the stock encoder cannot serialize are emitted as the string
    ``UNSERIALIZABLE[<repr>]`` instead of raising TypeError.
    """

    def default(self, o):
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
547
548
549
def safe_dumps(obj, **kwargs):
    """json.dumps using SafeJSONEncoder; extra keyword arguments are passed through."""
    encoder_cls = SafeJSONEncoder
    return json.dumps(obj, cls=encoder_cls, **kwargs)
551
552
553
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` in a generator that reports progress before each item.

    ``len(iterable)`` is taken immediately. For each item, ``format_string`` is
    formatted with ``pos`` (1-based), ``total``, ``value`` and ``kwargs``, shown
    via ``terminal_reporter.rewrite(..., black=True, bold=True)``, and then
    ``(message, item)`` is yielded.
    """
    count = len(iterable)

    def _generate():
        for index, item in enumerate(iterable, 1):
            message = format_string.format(pos=index, total=count, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return _generate()
563
564
565
def report_noprogress(iterable, *args, **kwargs):
    """Silent counterpart of report_progress: yields ``("", item)`` for each item.

    Extra arguments are accepted for signature compatibility and ignored.
    """
    for item in iterable:
        yield "", item
568
569
570
def slugify(name):
    """Replace path/filename-unsafe characters and spaces in ``name`` with ``_``.

    After each replaced character, one pass of ``__`` -> ``_`` is applied, which
    collapses some (not all) runs of underscores.
    """
    # "\\/" is an explicit backslash followed by a slash. The former "\/" was an
    # invalid escape sequence (DeprecationWarning, SyntaxWarning on 3.12+) that
    # happened to produce the same two characters.
    for c in "\\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
574
575
576
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Follows Windows (ntpath) rules regardless of platform. ``/`` is treated as
    ``\\``, comparison is case-insensitive, and the result keeps the drive and
    the component spelling from ``paths[0]``, joined with ``\\``. Works on str
    or bytes.

    :raises ValueError: for an empty sequence, a mix of absolute and relative
        paths, or paths on different drives.
    :raises TypeError: when str and bytes are mixed (or for non-path arguments).
    """
    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    def meaningful(components):
        # Drop empty components (from repeated/leading/trailing separators) and ".".
        return [c for c in components if c and c != curdir]

    try:
        normalized = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]

        try:
            isabs, = set(tail[:1] == sep for _, tail in normalized)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Drives are compared only after the normalization above, so str/bytes
        # mixing has already surfaced as a TypeError.
        if len(set(drv for drv, _ in normalized)) != 1:
            raise ValueError("Paths don't have the same drive")

        drive, first_tail = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = meaningful(first_tail.split(sep))

        pieces = [meaningful(tail.split(sep)) for _, tail in normalized]
        # The lexicographic min and max share a prefix exactly as long as the
        # prefix shared by all the lists.
        lowest, highest = min(pieces), max(pieces)
        size = len(lowest)
        for i, c in enumerate(lowest):
            if c != highest[i]:
                size = i
                break

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common[:size])
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
625
626
627
def get_cprofile_functions(stats):
    """
    Convert a pstats structure to a list of dicts, one per profiled function.

    ``stats.stats`` maps ``(filename, lineno, funcname)`` to pstats' tuple
    ``(primitive_calls, total_calls, tottime, cumtime, callers)``. The list
    follows the iteration order of ``stats.stats``; it is not sorted. Each dict
    has the keys ``ncalls_recursion``, ``ncalls``, ``tottime``, ``tottime_per``,
    ``cumtime``, ``cumtime_per`` and ``function_name``.
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())
    # Strip the parent only on a whole path-component boundary: "/home/user2/x.py"
    # must not be shortened just because the parent is "/home/user".
    parent_prefix = project_dir_parent.rstrip(os.sep) + os.sep

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(parent_prefix):
            file_path = file_path[len(parent_prefix):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        primitive_calls, total_calls, tottime, cumtime = run_info[:4]
        # if the function is recursive write number of 'total calls/primitive calls'
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{0}/{1}'.format(total_calls, primitive_calls)

        # Per-call figures match pstats' "percall" columns: tottime is divided
        # by total calls, cumtime by primitive calls.
        result.append(dict(ncalls_recursion=calls,
                           ncalls=total_calls,
                           tottime=tottime,
                           tottime_per=tottime / total_calls if total_calls > 0 else 0,
                           cumtime=cumtime,
                           cumtime_per=cumtime / primitive_calls if primitive_calls > 0 else 0,
                           function_name=function_name))

    return result
656