Completed
Pull Request — master (#64)
by
unknown
01:01
created

get_branch_info()   B

Complexity

Conditions 5

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
dl 0
loc 15
rs 8.5454
c 0
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import ntpath
8
import os
9
import platform
10
import re
11
import subprocess
12
import sys
13
import types
14
from datetime import datetime
15
from decimal import Decimal
16
from functools import partial
17
18
from .compat import PY3
19
20
try:
21
    from urllib.parse import urlparse, parse_qs
22
except ImportError:
23
    from urlparse import urlparse, parse_qs
24
25
try:
    from subprocess import check_output
except ImportError:
    def check_output(*popenargs, **kwargs):
        """Fallback used when ``subprocess`` does not provide ``check_output``.

        Starts the command through ``subprocess.Popen`` with stdout piped and
        returns whatever the process wrote to stdout.

        Raises ``ValueError`` when the caller passes ``stdout`` and
        ``subprocess.CalledProcessError`` when the exit status is non-zero.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        child = subprocess.Popen(*popenargs, stdout=subprocess.PIPE, **kwargs)
        captured = child.communicate()[0]
        status = child.poll()
        if not status:
            return captured
        # Report the command the same way it was given: ``args=`` keyword
        # first, otherwise the first positional argument.
        command = kwargs.get("args")
        if command is None:
            command = popenargs[0]
        raise subprocess.CalledProcessError(status, command)
40
41
# Human-readable labels for the unit prefixes returned by time_unit().
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
47
# Column names that parse_columns() accepts.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "rounds", "iterations"]
48
49
50
class SecondsDecimal(Decimal):
    """A ``Decimal`` number of seconds whose ``str()`` is a scaled time with an ``s`` suffix."""

    @property
    def as_string(self):
        """The plain ``Decimal`` text, without scaling or suffix."""
        return Decimal.__str__(self)

    def __float__(self):
        # Convert via the plain decimal text, not the overridden __str__.
        return float(Decimal.__str__(self))

    def __str__(self):
        seconds = float(Decimal.__str__(self))
        return "{0}s".format(format_time(seconds))
60
61
62
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its dotted ``module.name``."""

    def __init__(self, target):
        self.target = target

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)

    def __str__(self):
        wrapped = self.target
        # Module prefix only when the target exposes one.
        prefix = ""
        if hasattr(wrapped, '__module__'):
            prefix = wrapped.__module__ + "."
        # Fall back to repr() for targets without a __name__.
        if hasattr(wrapped, '__name__'):
            return prefix + wrapped.__name__
        return prefix + repr(wrapped)
73
74
75
def get_tag(project_name=None):
    """Build an underscore-joined tag from commit info and the current time.

    The parts are, in order: the project name (when non-empty), the commit id,
    the timestamp from ``get_current_time()`` and, when the working copy is
    dirty, a trailing marker.
    """
    commit = get_commit_info(project_name)
    pieces = [commit['project']] if commit['project'] else []
    pieces.append(commit['id'])
    pieces.append(get_current_time())
    if commit['dirty']:
        pieces.append("uncommited-changes")
    return "_".join(pieces)
85
86
87
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture>`` for this interpreter."""
    system = platform.system()
    implementation = platform.python_implementation()
    major_minor = ".".join(platform.python_version_tuple()[:2])
    bits = platform.architecture()[0]
    return "%s-%s-%s-%s" % (system, implementation, major_minor, bits)
94
95
96
def get_project_name():
    """Return a name for the project in the current working directory.

    For a git checkout (``.git`` exists) the name is the path component that
    precedes ``.git`` in the ``remote.origin.url`` setting. For a Mercurial
    checkout (``.hg`` exists) it is the last ``/``-separated component of the
    ``default`` path. In every other case -- no VCS directory, the VCS command
    failing, its executable not being runnable, or the URL not matching -- the
    basename of the current working directory is returned.
    """
    if os.path.exists('.git'):
        try:
            project_address = check_output("git config --local remote.origin.url".split())
            if isinstance(project_address, bytes) and str != bytes:
                project_address = project_address.decode()
            project_name = re.findall(r'/([^/]*)\.git', project_address)[0]
            return project_name
        except (IndexError, OSError, subprocess.CalledProcessError):
            # OSError: the ``git`` executable is missing or cannot be started.
            return os.path.basename(os.getcwd())
    elif os.path.exists('.hg'):
        try:
            project_address = check_output("hg path default".split())
            project_address = project_address.decode()
            project_name = project_address.split("/")[-1]
            return project_name.strip()
        except (IndexError, OSError, subprocess.CalledProcessError):
            # OSError: the ``hg`` executable is missing or cannot be started.
            return os.path.basename(os.getcwd())
    else:
        return os.path.basename(os.getcwd())
116
117
def get_branch_info():
    """Return the current VCS branch name as text.

    In a git checkout (``.git`` exists) this is the short symbolic ref of
    ``HEAD``, or ``'(detached head)'`` when git exits with an error. In a
    Mercurial checkout (``.hg`` exists) it is the output of ``hg branch``.
    ``'unknown'`` is returned when there is no VCS directory, when ``hg``
    fails, or when the VCS executable cannot be started.

    The commands run with ``universal_newlines=True`` so the result is text
    rather than bytes on Python 3.
    """
    if os.path.exists('.git'):
        try:
            branch = check_output('git symbolic-ref --short -q HEAD'.split(),
                                  universal_newlines=True)
            return branch.strip()
        except subprocess.CalledProcessError:
            return '(detached head)'
        except OSError:
            # The ``git`` executable is missing or cannot be started.
            return 'unknown'
    elif os.path.exists('.hg'):
        try:
            branch = check_output('hg branch'.split(), universal_newlines=True)
            return branch.strip()
        except (OSError, subprocess.CalledProcessError):
            return 'unknown'
    else:
        return 'unknown'
132
133
def get_commit_info(project_name=None):
    """Describe the current commit of the working directory as a dict.

    The result always has ``id``, ``dirty``, ``project`` and ``branch`` keys.
    ``id`` is ``'unversioned'`` when neither ``.git`` nor ``.hg`` exists. If
    anything goes wrong while querying the VCS, ``id`` is ``'unknown'`` and an
    extra ``error`` key holds the ``repr()`` of the exception.
    """
    dirty = False
    commit = 'unversioned'
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    try:
        if os.path.exists('.git'):
            described = check_output(
                'git describe --dirty --always --long --abbrev=40'.split(),
                universal_newlines=True
            )
            fields = described.strip().split('-')
            # A trailing "dirty" field marks uncommitted changes.
            if fields[-1].strip() == 'dirty':
                dirty = True
                del fields[-1]
            commit = fields[-1].strip('g')
        elif os.path.exists('.hg'):
            identity = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            # A trailing "+" marks uncommitted changes.
            if identity[-1] == '+':
                dirty = True
            commit = identity.strip('+')
    except Exception as exc:
        return {
            'id': 'unknown',
            'dirty': dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }
    return {
        'id': commit,
        'dirty': dirty,
        'project': project_name,
        'branch': branch,
    }
166
167
168
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
170
171
172
def first_or_value(obj, value):
    """Return the sole element of ``obj``, or ``value`` when ``obj`` is falsy.

    A truthy ``obj`` must unpack to exactly one element; otherwise the
    unpacking raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
177
178
179
def short_filename(path, machine_id=None):
    """Return a ``/``-joined short form of ``path``.

    Objects without a ``parts`` attribute are returned as ``str(path)``.
    Otherwise the first component is dropped when it equals ``machine_id`` and
    the last component loses everything after its final ``.``.
    """
    try:
        names = list(path.parts)
    except AttributeError:
        return str(path)
    # Decide on the original first component, before the extension is removed.
    drop_first = bool(names) and names[0] == machine_id
    if names:
        names[-1] = names[-1].rsplit('.', 1)[0]
    if drop_first:
        names = names[1:]
    return '/'.join(names)
194
195
196
def load_timer(string):
    """Resolve a dotted ``module.attr`` string to a ``NameWrapper`` around the attribute.

    The pseudo-module ``pep418`` maps to the stdlib ``time`` module on Python 3
    and to the package-local ``pep418`` module otherwise.

    Raises ``argparse.ArgumentTypeError`` when ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod != 'pep418':
        __import__(mod)
        return NameWrapper(getattr(sys.modules[mod], attr))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr))
    from . import pep418
    return NameWrapper(getattr(pep418, attr))
211
212
213
class RegressionCheck(object):
    """Base for checks that compare one stats field against a threshold.

    Subclasses provide ``compute(current, compared)``.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else ``None``."""
        measured = self.compute(current, compared)
        if not measured > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, measured, self.threshold
        )
224
225
226
class PercentageRegressionCheck(RegressionCheck):
    """Check expressed as the percentage change of ``current`` relative to ``compared``."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A falsy baseline (e.g. zero) cannot be divided by.
        return float("inf")
232
233
234
class DifferenceRegressionCheck(RegressionCheck):
    """Check expressed as the absolute difference ``current - compared``."""

    def compute(self, current, compared):
        now = current[self.field]
        before = compared[self.field]
        return now - before
237
238
239
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``field:threshold`` expression into a regression check.

    ``field`` is one of min, max, mean, median, stddev or iqr. A threshold of
    one or two digits followed by ``%`` yields a ``PercentageRegressionCheck``;
    a plain number (optionally with an exponent) yields a
    ``DifferenceRegressionCheck``.

    The pattern is written with raw string literals so that ``\\.`` is not an
    invalid string escape.

    Raises ``argparse.ArgumentTypeError`` when ``string`` does not match.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
251
252
253
def parse_warmup(string):
    """Interpret a warmup option value as a boolean.

    ``auto`` is true only on PyPy. ``off``/``false``/``no`` are false;
    ``on``/``true``/``yes`` and the empty string are true. Matching ignores
    case and surrounding whitespace.

    Raises ``argparse.ArgumentTypeError`` for any other value.
    """
    string = string.lower().strip()
    if string == "auto":
        return platform.python_implementation() == "PyPy"
    if string in ("off", "false", "no"):
        return False
    if string in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
263
264
265
def name_formatter_short(bench):
    """Return the benchmark name in short form.

    When ``bench["source"]` is truthy, the first four characters of the last
    path component of the source are appended in parentheses. A leading
    ``test_`` is then removed from the result.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, os.path.basename(source))
    return label[5:] if label.startswith("test_") else label
272
273
274
def name_formatter_normal(bench):
    """Return the benchmark name, followed by its source in parentheses when present.

    The last ``/``-separated component of the source is truncated to 12
    characters.
    """
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    head, slash, tail = source.rpartition('/')
    return "%s (%s)" % (label, head + slash + tail[:12])
281
282
283
def name_formatter_long(bench):
    """Return the benchmark's full name, followed by its full source in parentheses when present."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
288
289
290
# Maps each accepted name-format option to the function that renders it.
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
)
295
296
297
def parse_name_format(string):
    """Validate a name-format option and return it lower-cased and stripped.

    Raises ``argparse.ArgumentTypeError`` when the value is not a key of
    ``NAME_FORMATTERS``.
    """
    string = string.lower().strip()
    if string not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
    return string
303
304
305
def parse_timer(string):
    """Validate a dotted timer name via ``load_timer`` and return its string form."""
    timer = load_timer(string)
    return str(timer)
307
308
309
def parse_sort(string):
    """Validate a sort-column option and return it lower-cased and stripped.

    Raises ``argparse.ArgumentTypeError`` for anything other than min, max,
    mean, stddev, name or fullname.
    """
    string = string.lower().strip()
    accepted = ("min", "max", "mean", "stddev", "name", "fullname")
    if string in accepted:
        return string
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % string)
317
318
319
def parse_columns(string):
    """Split a comma-separated column list into lower-cased, stripped names.

    Raises ``argparse.ArgumentTypeError`` when any name is not in
    ``ALLOWED_COLUMNS``.
    """
    columns = list(map(str.strip, string.lower().split(',')))
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if not invalid:
        return columns
    # there are extra items in columns!
    msg = "".join([
        "Invalid column name(s): %s. " % ', '.join(invalid),
        "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS),
    ])
    raise argparse.ArgumentTypeError(msg)
328
329
330
def parse_rounds(string):
    """Parse a rounds option into an integer of at least 1.

    Raises ``argparse.ArgumentTypeError`` when the value is not an integer or
    is smaller than 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
339
340
341
def parse_seconds(string):
    """Validate a decimal seconds value and return its plain decimal string.

    Raises ``argparse.ArgumentTypeError`` when the value cannot be converted.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
346
347
348
def parse_save(string):
    """Validate a save name and return it unchanged.

    Raises ``argparse.ArgumentTypeError`` when the name is empty or contains
    any of the characters ``\\ / : * ? < > |``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [c for c in r"\/:*?<>|" if c in string]
    if found:
        illegal = ''.join(found)
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string
355
356
357
def parse_elasticsearch_storage(string, default_index="benchmark", default_doctype="benchmark"):
    """Split an Elasticsearch storage URL into ``(hosts, index, doctype, project_name)``.

    ``hosts`` has one ``scheme://netloc`` entry per comma-separated netloc.
    The first and second path segments, when present, override the default
    index and doctype. ``project_name`` comes from the ``project_name`` query
    parameter, falling back to ``get_project_name()``.
    """
    parsed = urlparse(string)
    hosts = [
        "{scheme}://{netloc}".format(scheme=parsed.scheme, netloc=netloc)
        for netloc in parsed.netloc.split(',')
    ]
    index, doctype = default_index, default_doctype
    if parsed.path and parsed.path != "/":
        segments = parsed.path.strip("/").split("/")
        index = segments[0]
        if len(segments) >= 2:
            doctype = segments[1]
    params = parse_qs(parsed.query)
    if "project_name" in params:
        project_name = params["project_name"][0]
    else:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
373
374
375
def load_storage(storage, **kwargs):
    """Create the storage backend selected by a storage URL.

    A value without ``://`` is treated as a ``file://`` path. ``file://``
    URLs produce a ``FileStorage`` and ``elasticsearch+...`` URLs an
    ``ElasticsearchStorage``; ``kwargs`` are forwarded to the backend.

    Raises ``argparse.ArgumentTypeError`` for any other scheme.
    """
    file_prefix = "file://"
    es_prefix = "elasticsearch+"
    if "://" not in storage:
        storage = file_prefix + storage
    if storage.startswith(file_prefix):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_prefix):], **kwargs)
    if storage.startswith(es_prefix):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        return ElasticsearchStorage(*parse_elasticsearch_storage(storage[len(es_prefix):]), **kwargs)
    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
388
389
390
def time_unit(value):
    """Return ``(prefix, multiplier)`` for displaying ``value`` seconds.

    The prefix is ``n``, ``u``, ``m`` or the empty string, chosen by the first
    upper bound (1e-6, 1e-3, 1) that ``value`` is below.
    """
    scales = (
        (1e-6, "n", 1e9),
        (1e-3, "u", 1e6),
        (1, "m", 1e3),
    )
    for upper_bound, prefix, multiplier in scales:
        if value < upper_bound:
            return prefix, multiplier
    return "", 1.
399
400
401
def format_time(value):
    """Format ``value`` seconds with two decimals, scaled by ``time_unit`` and followed by its prefix."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
404
405
406
class cached_property(object):
    """Descriptor that computes a value once and stores it in the instance ``__dict__``.

    The value is stored under the wrapped function's name.
    """

    def __init__(self, func):
        self.__doc__ = func.__doc__
        self.func = func

    def __get__(self, obj, cls):
        # Accessed on the class itself: hand back the descriptor.
        if obj is None:
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
416
417
418
def funcname(f):
    """Return the name of ``f``, looking through a ``functools.partial``.

    Falls back to ``str(f)`` when no ``__name__`` is available.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
426
427
428
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects that lack the attributes needed for cloning (for example anything
    without ``__code__``) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: the positional CodeType constructor signature is
            # version-specific, so let the code object copy itself.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            # Python 3 before 3.8.
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            # Python 2.
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        f2 = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() takes no keyword-only defaults; carry them over so the
        # clone accepts the same calls as the original.
        kwdefaults = getattr(f, '__kwdefaults__', None)
        if kwdefaults:
            f2.__kwdefaults__ = dict(kwdefaults)
        return f2
    except AttributeError:
        return f
460
461
462
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items in sorted order."""
    rendered = ["%s: %s" % (key, json.dumps(value)) for key, value in sorted(obj.items())]
    return "{%s}" % ", ".join(rendered)
464
465
466
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that substitutes a placeholder string for unserializable objects."""

    def default(self, o):
        # Called by the encoder for objects it cannot serialize itself.
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
469
470
471
def safe_dumps(obj, **kwargs):
    """Serialize ``obj`` to JSON using ``SafeJSONEncoder``; ``kwargs`` go to ``json.dumps``."""
    encoder = SafeJSONEncoder
    return json.dumps(obj, cls=encoder, **kwargs)
473
474
475
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Return a generator of ``(progress_string, item)`` pairs over ``iterable``.

    Before each pair is yielded, the progress string -- ``format_string``
    formatted with ``pos`` (1-based), ``total``, ``value`` and ``kwargs`` -- is
    passed to ``terminal_reporter.rewrite``. ``len(iterable)`` is taken
    immediately, when this function is called.
    """
    total = len(iterable)

    def progress_reporting_wrapper():
        for position, item in enumerate(iterable, 1):
            string = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(string, black=True, bold=True)
            yield string, item

    return progress_reporting_wrapper()
484
485
486
def report_noprogress(iterable, *args, **kwargs):
    """Yield ``("", item)`` for each item; extra arguments are accepted and ignored."""
    for item in iterable:
        yield "", item
489
490
491
def slugify(name):
    """Replace ``\\ / : * ? < > |`` and spaces in ``name`` with underscores.

    After each character is replaced, double underscores are collapsed to a
    single one. The character set is a raw string so the leading backslash is
    a literal backslash rather than an invalid escape sequence.
    """
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
495
496
497
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are handled with ``ntpath`` rules: ``/`` is treated as ``\\``,
    comparison is done on lower-cased paths, and the result uses ``\\`` as
    separator with the drive and component spelling of ``paths[0]``.

    Raises ``ValueError`` for an empty sequence, for a mix of absolute and
    relative paths, and for paths on different drives.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    def components(tail):
        # Split on the separator, dropping empty and current-directory parts.
        return [c for c in tail.split(sep) if c and c != curdir]

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]

        absolute_flags = set(tail[:1] == sep for _, tail in drivesplits)
        if len(absolute_flags) != 1:
            raise ValueError("Can't mix absolute and relative paths")
        isabs = absolute_flags.pop()

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(drive for drive, _ in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result is cut from the first path as given, not lower-cased.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = components(path)

        split_paths = [components(tail) for _, tail in drivesplits]
        lowest = min(split_paths)
        highest = max(split_paths)
        # The common prefix of the smallest and largest lists is common to all.
        shared = len(lowest)
        for i, c in enumerate(lowest):
            if c != highest[i]:
                shared = i
                break
        common = common[:shared]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
546
547
548
def get_cprofile_functions(stats):
    """
    Convert pstats structure to a list of dicts, one per profiled function.

    Each dict holds the call counts, total and cumulative time, their per-call
    values and a ``file:line(function)`` name. File paths under the parent of
    the current working directory are made relative to it.
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = os.path.dirname(os.getcwd())
    result = []

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        primitive_calls = run_info[0]
        total_calls = run_info[1]
        tottime = run_info[2]
        cumtime = run_info[3]

        # if the function is recursive write number of 'total calls/primitive calls'
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{1}/{0}'.format(primitive_calls, total_calls)

        has_calls = primitive_calls > 0
        result.append({
            'ncalls_recursion': calls,
            'ncalls': total_calls,
            'tottime': tottime,
            'tottime_per': tottime / primitive_calls if has_calls else 0,
            'cumtime': cumtime,
            'cumtime_per': cumtime / primitive_calls if has_calls else 0,
            'function_name': function_name,
        })

    return result
577