Completed
Pull Request — master (#64)
by
unknown
01:25
created

get_branch_info()   B

Complexity

Conditions 6

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 1 Features 0
Metric Value
cc 6
c 2
b 1
f 0
dl 0
loc 17
rs 8

1 Method

Rating   Name   Duplication   Size   Complexity  
A cmd() 0 3 1
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import ntpath
8
import os
9
import platform
10
import re
11
import subprocess
12
import sys
13
import types
14
from datetime import datetime
15
from decimal import Decimal
16
from functools import partial
17
18
from .compat import PY3
19
20
try:
21
    from urllib.parse import urlparse, parse_qs
22
except ImportError:
23
    from urlparse import urlparse, parse_qs
24
25
try:
    from subprocess import check_output
except ImportError:
    # Fallback for interpreters whose subprocess module has no check_output
    # (it was added in Python 2.7).
    def check_output(*popenargs, **kwargs):
        """Run a command and return its captured stdout.

        Mirrors ``subprocess.check_output``: ``stdout`` may not be passed because
        it is always captured. On a non-zero exit status a
        ``subprocess.CalledProcessError`` is raised with the captured output
        attached as ``.output``, so callers reading ``e.output`` keep working.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, _unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            # The old CalledProcessError signature has no ``output`` argument,
            # so attach it after construction.
            error = subprocess.CalledProcessError(retcode, cmd)
            error.output = output
            raise error
        return output
40
41
# Human-readable labels for the unit prefixes produced by time_unit().
TIME_UNITS = dict([
    ("", "Seconds"),
    ("m", "Miliseconds (ms)"),
    ("u", "Microseconds (us)"),
    ("n", "Nanoseconds (ns)"),
])
# Column names accepted by parse_columns().
ALLOWED_COLUMNS = "min max mean stddev median iqr outliers rounds iterations".split()
48
49
50
class SecondsDecimal(Decimal):
    """A Decimal number of seconds whose ``str()`` is a human-friendly duration."""

    def __float__(self):
        # Go through Decimal's own textual form; str(self) is overridden below.
        return float(Decimal.__str__(self))

    def __str__(self):
        # Scaled and unit-prefixed by format_time(), then suffixed with "s".
        plain = Decimal.__str__(self)
        return "{0}s".format(format_time(float(plain)))

    @property
    def as_string(self):
        """The plain decimal representation, without unit formatting."""
        return Decimal.__str__(self)
60
61
62
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields a readable dotted name.

    ``str()`` gives ``module.name`` when the target has a module name,
    otherwise just its ``__name__``; targets without ``__name__`` fall back to
    their ``repr()``. Used by load_timer() to name timer callables.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        # ``__module__`` may be absent, or present but None (e.g. on some
        # builtin callables); only prefix a real module name instead of
        # failing on ``None + "."``.
        module = getattr(self.target, '__module__', None)
        prefix = module + "." if module else ""
        if hasattr(self.target, '__name__'):
            return prefix + self.target.__name__
        return prefix + repr(self.target)

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
73
74
75
def get_tag(project_name=None):
    """Build a tag ``[<project>_]<commit id>_<timestamp>[_uncommited-changes]``.

    The pieces come from get_commit_info() and get_current_time().
    """
    info = get_commit_info(project_name)
    pieces = [info['project']] if info['project'] else []
    pieces += [info['id'], get_current_time()]
    if info['dirty']:
        pieces.append("uncommited-changes")
    return "_".join(pieces)
85
86
87
def get_machine_id():
    """Identify the host as ``<system>-<implementation>-<major.minor>-<architecture bits>``."""
    fields = [
        platform.system(),
        platform.python_implementation(),
        ".".join(platform.python_version_tuple()[:2]),
        platform.architecture()[0],
    ]
    return "-".join(fields)
94
95
96
def get_project_name():
    """Guess the project's name from the VCS checkout in the current directory.

    * ``.git`` in the cwd: the ``<name>`` of the first ``/<name>.git`` found in
      ``git config --local remote.origin.url``.
    * ``.hg`` in the cwd: the last ``/``-separated part of ``hg path default``,
      stripped of whitespace.
    * Otherwise -- and whenever the VCS command fails, cannot be started (e.g.
      the tool is not installed), emits undecodable bytes, or yields a URL that
      does not match -- the basename of the current working directory.
    """
    try:
        if os.path.exists('.git'):
            project_address = check_output("git config --local remote.origin.url".split())
            if isinstance(project_address, bytes) and str != bytes:
                project_address = project_address.decode()
            return re.findall(r'/([^/]*)\.git', project_address)[0]
        elif os.path.exists('.hg'):
            project_address = check_output("hg path default".split())
            # Same bytes->text handling as the git case.
            if isinstance(project_address, bytes) and str != bytes:
                project_address = project_address.decode()
            return project_address.split("/")[-1].strip()
    except (IndexError, subprocess.CalledProcessError, OSError, UnicodeDecodeError):
        # Best effort: any of these falls back to the directory name below.
        pass
    return os.path.basename(os.getcwd())
116
117
118
def get_branch_info():
    """Return the current VCS branch name for the working directory.

    Only the current directory is checked for ``.git`` / ``.hg``.

    Returns:
        The branch name, ``'(detached head)'`` when git reports ``HEAD``,
        ``'unknown'`` when neither ``.git`` nor ``.hg`` exists, or an
        ``'ERROR: ...'`` string when the VCS command fails or cannot be run.
    """
    def cmd(s):
        # stderr is merged into stdout so a failing command's message ends up
        # in CalledProcessError.output.
        args = s.split()
        return check_output(args, stderr=subprocess.STDOUT, universal_newlines=True)

    try:
        if os.path.exists('.git'):
            branch = cmd('git rev-parse --abbrev-ref HEAD').strip()
            if branch == 'HEAD':
                return '(detached head)'
            return branch
        elif os.path.exists('.hg'):
            return cmd('hg branch').strip()
        else:
            return 'unknown'
    except subprocess.CalledProcessError as e:
        # ``output`` may be None or absent (e.g. a check_output backport that
        # does not attach it); strip the command's trailing newline.
        output = getattr(e, 'output', None) or ''
        return 'ERROR: %s' % output.strip()
    except OSError as e:
        # The VCS executable could not be started (e.g. not installed).
        return 'ERROR: %s' % e
135
136
137
def get_commit_info(project_name=None):
    """Collect version-control metadata for the current working directory.

    Args:
        project_name: explicit project name; when falsy it is guessed with
            get_project_name().

    Returns:
        A dict with ``id`` (commit identifier parsed from ``git describe`` /
        ``hg id``; ``'unversioned'`` when the cwd has neither ``.git`` nor
        ``.hg``; ``'unknown'`` on failure), ``dirty`` (uncommitted changes
        reported), ``project`` and ``branch``. If reading the commit raises,
        an ``error`` key holds ``repr()`` of the exception.
    """
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    commit = 'unversioned'
    dirty = False
    try:
        if os.path.exists('.git'):
            described = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                     universal_newlines=True).strip()
            pieces = described.split('-')
            # "--dirty" appends a "-dirty" suffix; drop it after noting it.
            dirty = pieces[-1].strip() == 'dirty'
            if dirty:
                pieces.pop()
            # Last piece is "g<hash>" (or the bare hash); drop the "g" marker.
            commit = pieces[-1].strip('g')
        elif os.path.exists('.hg'):
            ident = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            # A trailing "+" marks uncommitted changes.
            dirty = ident[-1] == '+'
            commit = ident.strip('+')
        info = {
            'id': commit,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        return {
            'id': 'unknown',
            'dirty': dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }
    return info
170
171
172
def get_current_time():
    """Return the current UTC time formatted as ``YYYYmmdd_HHMMSS``."""
    try:
        from datetime import timezone
    except ImportError:
        # Python 2 has no datetime.timezone; utcnow() is the only option there.
        now = datetime.utcnow()
    else:
        # datetime.utcnow() is deprecated since Python 3.12; an aware UTC
        # "now" has the same wall-clock fields for strftime().
        now = datetime.now(timezone.utc)
    return now.strftime("%Y%m%d_%H%M%S")
174
175
176
def first_or_value(obj, value):
    """Return the single element of ``obj`` if ``obj`` is truthy, else ``value``.

    A truthy ``obj`` that does not unpack to exactly one item raises
    ValueError (or TypeError if it is not iterable).
    """
    if not obj:
        return value
    (only,) = obj
    return only
181
182
183
def short_filename(path, machine_id=None):
    """Return a compact ``/``-joined form of a pathlib-style path.

    A leading component equal to ``machine_id`` is dropped, and the last
    component loses its final extension. Objects without ``.parts`` are
    returned as ``str(path)``.
    """
    try:
        final = len(path.parts) - 1
    except AttributeError:
        return str(path)
    pieces = []
    for index, component in enumerate(path.parts):
        if index == 0 and component == machine_id:
            continue
        if index == final:
            component = component.rsplit('.', 1)[0]
        pieces.append(component)
    return '/'.join(pieces)
198
199
200
def load_timer(string):
    """Resolve a dotted ``module.attr`` timer spec and wrap it in NameWrapper.

    The pseudo-module ``pep418`` maps to :mod:`time` on Python 3 and to this
    package's own ``pep418`` module otherwise.

    Raises:
        argparse.ArgumentTypeError: if ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr_name = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        source = sys.modules[module_name]
    elif PY3:
        import time as source
    else:
        from . import pep418 as source
    return NameWrapper(getattr(source, attr_name))
215
216
217
class RegressionCheck(object):
    """Base class comparing a field of two benchmark stats mappings.

    Subclasses provide ``compute(current, compared)`` returning a number that
    is checked against ``threshold``.
    """

    def __init__(self, field, threshold):
        self.field = field          # key looked up in both mappings
        self.threshold = threshold  # largest computed value that still passes

    def fails(self, current, compared):
        """Return a failure message if the computed value exceeds the threshold, else None."""
        value = self.compute(current, compared)
        if not value > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, value, self.threshold
        )
228
229
230
class PercentageRegressionCheck(RegressionCheck):
    """Regression check on the percentage change of ``field`` versus the compared run."""

    def compute(self, current, compared):
        # A falsy baseline (e.g. 0) yields an infinite change, which exceeds
        # any finite threshold in fails().
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        return float("inf")
236
237
238
class DifferenceRegressionCheck(RegressionCheck):
    """Regression check on the absolute difference of ``field`` versus the compared run."""

    def compute(self, current, compared):
        # Signed: positive when the current value is larger.
        current_value = current[self.field]
        return current_value - compared[self.field]
241
242
243
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse a ``field:threshold`` regression spec into a check object.

    ``field`` is one of min, max, mean, median, stddev, iqr.

    * ``field:N%`` (N is one or two digits, i.e. 0-99) -> PercentageRegressionCheck
      with an int threshold.
    * ``field:X`` (a non-negative decimal, optionally with exponent) ->
      DifferenceRegressionCheck with a float threshold.

    Raises:
        argparse.ArgumentTypeError: if ``string`` matches neither form.
    """
    # The pattern uses raw strings: "\." in a normal literal is an invalid
    # escape sequence (SyntaxWarning since Python 3.12). The regex is unchanged.
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
255
256
257
def parse_warmup(string):
    """Parse a warmup switch (case-insensitive, surrounding whitespace ignored).

    ``on``/``true``/``yes``/empty -> True; ``off``/``false``/``no`` -> False;
    ``auto`` -> True only when running on PyPy.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    value = string.lower().strip()
    if value == "auto":
        return platform.python_implementation() == "PyPy"
    answers = {
        "off": False, "false": False, "no": False,
        "on": True, "true": True, "yes": True, "": True,
    }
    answer = answers.get(value)
    if answer is None:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % value)
    return answer
267
268
269
def name_formatter_short(bench):
    """Short display name.

    The name is followed by up to 4 characters of the source file's basename,
    and a leading ``test_`` is removed.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, os.path.split(source)[-1])
    return label[5:] if label.startswith("test_") else label
276
277
278
def name_formatter_normal(bench):
    """Normal display name: the name plus the source path, with its last
    ``/``-separated component cut to 12 characters."""
    name = bench["name"]
    source = bench["source"]
    if not source:
        return name
    head, sep, tail = source.rpartition('/')
    return "%s (%s)" % (name, head + sep + tail[:12])
285
286
287
def name_formatter_long(bench):
    """Long display name: ``fullname``, followed by the source in parentheses when set."""
    if not bench["source"]:
        return bench["fullname"]
    return "%s (%s)" % (bench["fullname"], bench["source"])
292
293
294
# Maps each value accepted by parse_name_format() to its name formatter.
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
)
299
300
301
def parse_name_format(string):
    """Validate a name-format option; return it lower-cased and stripped.

    Raises:
        argparse.ArgumentTypeError: unless the value is a key of NAME_FORMATTERS.
    """
    normalized = string.lower().strip()
    if normalized not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
    return normalized
307
308
309
def parse_timer(string):
    """Validate a ``--benchmark-timer`` value via load_timer(); return the timer's display name."""
    wrapper = load_timer(string)
    return str(wrapper)
311
312
313
def parse_sort(string):
    """Validate a ``--benchmark-sort`` key (case-insensitive); return it lower-cased and stripped."""
    choice = string.lower().strip()
    if choice in ("min", "max", "mean", "stddev", "name", "fullname"):
        return choice
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % choice)
321
322
323
def parse_columns(string):
    """Parse a comma-separated list of column names (case-insensitive).

    Returns:
        The lower-cased, whitespace-stripped names, in the order given.

    Raises:
        argparse.ArgumentTypeError: if any name is not in ALLOWED_COLUMNS. The
            offending names are listed in sorted order so the message is
            deterministic.
    """
    # s.strip() rather than str.strip(s): the unbound form rejects unicode
    # input on Python 2.
    columns = [s.strip() for s in string.lower().split(',')]
    invalid = sorted(set(columns) - set(ALLOWED_COLUMNS))
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns
332
333
334
def parse_rounds(string):
    """Parse ``--benchmark-rounds``: an integer that must be at least 1.

    Raises:
        argparse.ArgumentTypeError: if not an integer or less than 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds >= 1:
        return rounds
    raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
343
344
345
def parse_seconds(string):
    """Parse a seconds value; return SecondsDecimal's plain decimal string form.

    Any error while converting is re-raised as argparse.ArgumentTypeError.
    """
    try:
        parsed = SecondsDecimal(string)
        return parsed.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
350
351
352
def parse_save(string):
    """Validate a save name: non-empty and free of characters awkward in file names.

    Raises:
        argparse.ArgumentTypeError: if empty or containing any of ``\\/:*?<>|``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    bad = [ch for ch in r"\/:*?<>|" if ch in string]
    if bad:
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(bad))
    return string
359
360
361
def parse_elasticsearch_storage(string, default_index="benchmark", default_doctype="benchmark"):
    """Split ``scheme://host1[,host2...][/index[/doctype]][?project_name=...]``.

    Returns:
        ``(hosts, index, doctype, project_name)``. Each host keeps the URL
        scheme; index/doctype fall back to the given defaults, and
        project_name falls back to get_project_name().
    """
    url = urlparse(string)
    hosts = ["{scheme}://{netloc}".format(scheme=url.scheme, netloc=netloc) for netloc in url.netloc.split(',')]
    index, doctype = default_index, default_doctype
    if url.path and url.path != "/":
        segments = url.path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]
    params = parse_qs(url.query)
    try:
        project_name = params["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
377
378
379
def load_storage(storage, **kwargs):
    """Instantiate a storage backend from a location string.

    A value without ``://`` is treated as a ``file://`` path. ``file://`` gives
    a FileStorage; ``elasticsearch+<url>`` gives an ElasticsearchStorage
    configured by parse_elasticsearch_storage(). Extra ``kwargs`` go to the
    backend constructor.

    Raises:
        argparse.ArgumentTypeError: for any other scheme.
    """
    if "://" not in storage:
        storage = "file://" + storage
    file_prefix = "file://"
    es_prefix = "elasticsearch+"
    if storage.startswith(file_prefix):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_prefix):], **kwargs)
    if storage.startswith(es_prefix):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        return ElasticsearchStorage(*parse_elasticsearch_storage(storage[len(es_prefix):]), **kwargs)
    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
392
393
394
def time_unit(value):
    """Pick a display unit for ``value`` seconds.

    Returns:
        ``(prefix, factor)`` such that ``value * factor`` is the time in that
        unit: below 1e-6 -> ("n", 1e9), below 1e-3 -> ("u", 1e6),
        below 1 -> ("m", 1e3), otherwise ("", 1.0).
    """
    for upper_bound, prefix, factor in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < upper_bound:
            return prefix, factor
    return "", 1.
403
404
405
def format_time(value):
    """Format ``value`` seconds with two decimals in the unit chosen by time_unit()."""
    unit, factor = time_unit(value)
    scaled = value * factor
    return "{0:.2f}{1:s}".format(scaled, unit)
408
409
410
class cached_property(object):
    """Descriptor that computes an attribute once per instance.

    On first access the wrapped function's result is stored in the instance
    ``__dict__`` under the function's name. This class defines no ``__set__``,
    so later lookups find the stored value and bypass the descriptor.
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is not None:
            result = self.func(obj)
            obj.__dict__[self.func.__name__] = result
            return result
        # Accessed on the class itself: expose the descriptor.
        return self
420
421
422
def funcname(f):
    """Best-effort name of a callable.

    ``functools.partial`` objects are unwrapped to the wrapped function's
    name; anything without ``__name__`` falls back to ``str(f)``.
    """
    try:
        target = f.func if isinstance(f, partial) else f
        return target.__name__
    except AttributeError:
        return str(f)
430
431
432
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the needed function attributes (``__code__``,
    ``__globals__``, ...), such as builtins or ``functools.partial`` objects,
    are returned unchanged. Keyword-only defaults are carried over to the clone.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: the CodeType constructor gained parameters in 3.8
            # (posonlyargcount) and later versions, so the positional calls
            # below raise TypeError there. code.replace() returns a copy.
            co2 = co.replace()
        elif PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() takes no keyword-only defaults; copy them so the clone
        # accepts the same calls as the original.
        kwdefaults = getattr(f, '__kwdefaults__', None)
        if kwdefaults:
            clone.__kwdefaults__ = dict(kwdefaults)
        return clone
    except AttributeError:
        return f
464
465
466
def format_dict(obj):
    """Render ``obj`` as ``{key: <json value>, ...}`` sorted by item.

    Keys are inserted with ``%s`` (not JSON-quoted); values go through
    ``json.dumps``.
    """
    items = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{%s}" % ", ".join(items)
468
469
470
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that never fails on unsupported objects.

    Anything the base encoder cannot serialize is emitted as the string
    ``"UNSERIALIZABLE[<repr>]"`` instead of raising TypeError.
    """

    def default(self, o):
        # Wrap ``o`` in a 1-tuple: a bare tuple on the right of ``%`` would be
        # treated as the argument list and raise or misformat.
        return "UNSERIALIZABLE[%r]" % (o,)
473
474
475
def safe_dumps(obj, **kwargs):
    """``json.dumps`` with SafeJSONEncoder, so unserializable values become placeholder strings."""
    encoder = SafeJSONEncoder
    return json.dumps(obj, cls=encoder, **kwargs)
477
478
479
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Wrap ``iterable`` in a generator that reports progress as it goes.

    ``len(iterable)`` is taken immediately, so ``iterable`` must be sized. For
    each item, ``format_string`` is formatted with ``pos`` (1-based),
    ``total``, ``value`` and any extra ``kwargs``. The result is written via
    ``terminal_reporter.rewrite(..., black=True, bold=True)``, and
    ``(string, item)`` is yielded.
    """
    total = len(iterable)

    def _generate():
        for position, item in enumerate(iterable, 1):
            message = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return _generate()
488
489
490
def report_noprogress(iterable, *args, **kwargs):
    """Silent counterpart of report_progress(): yield ``("", item)`` per item.

    Extra arguments are accepted and ignored.
    """
    for item in iterable:
        yield "", item
493
494
495
def slugify(name):
    """Replace characters that are awkward in file names with underscores.

    Each of ``\\ / : * ? < > |`` and space is replaced by ``_``. After each
    replacement, one non-overlapping pass collapses ``__`` into ``_``. Longer
    runs (e.g. three consecutive spaces) can therefore still leave ``__``.
    """
    # Raw string: "\/" in a normal literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12); the characters are unchanged.
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
499
500
501
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Uses Windows-style (``ntpath``) semantics:

    * ``/`` is treated as ``\\``.
    * Components are compared case-insensitively (via ``.lower()``).
    * Empty and ``.`` components are ignored.

    The result is joined with ``\\``, keeps the original casing of
    ``paths[0]``, and starts with the drive (plus a separator when the paths
    are absolute). NOTE(review): this appears to be a backport of the
    standard library's ``ntpath.commonpath`` -- confirm against the Python
    version it was copied from.

    ``paths`` must be a non-empty, indexable sequence of all-``str`` or
    all-``bytes`` paths.

    Raises ValueError if ``paths`` is empty, mixes absolute and relative
    paths, or the paths are on different drives. On TypeError/AttributeError
    (e.g. mixed ``str``/``bytes``), ``genericpath._check_arg_types`` is
    called, presumably to raise a more descriptive TypeError; otherwise the
    original error is re-raised.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Separator literals must match the argument type (str or bytes).
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        # Separator-unified, lower-cased copies used only for comparison.
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        # All paths must agree on being absolute; a two-element set fails the
        # single-item unpack with ValueError.
        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The returned components come from paths[0] (not lower-cased), so its
        # original casing is preserved.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        # The common prefix of the lexicographically smallest and largest
        # component lists is the common prefix of all of them.
        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            # No mismatch: s1 is itself a prefix of s2.
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
550
551
552
def get_cprofile_functions(stats):
    """
    Convert a pstats.Stats-like object into a list of dicts, one per function.

    The list follows the iteration order of ``stats.stats``; it is not sorted
    here. Each dict has:

    * ``function_name``: ``file:line(name)``. Files under the parent of the
      current working directory are shown relative to that parent.
    * ``ncalls_recursion``: total calls, or ``total/primitive`` when they
      differ (recursive functions), as pstats prints it.
    * ``ncalls``: total call count.
    * ``tottime`` / ``tottime_per``: internal time, and that time divided by
      total calls (as pstats does).
    * ``cumtime`` / ``cumtime_per``: cumulative time, and that time divided by
      primitive calls (as pstats does).
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = os.path.dirname(os.getcwd())
    # Strip whole path components only: "/home/a" must not match "/home/ab/...".
    prefix = project_dir_parent.rstrip(os.sep) + os.sep

    for function_info, run_info in stats.stats.items():
        # stats keys are (file, line, function name); values start with
        # (primitive calls, total calls, tottime, cumtime).
        primitive_calls, total_calls, tottime, cumtime = run_info[:4]
        file_path = function_info[0]
        if file_path.startswith(prefix):
            file_path = file_path[len(prefix):]
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive write number of 'total calls/primitive calls'
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{0}/{1}'.format(total_calls, primitive_calls)

        result.append(dict(ncalls_recursion=calls,
                           ncalls=total_calls,
                           tottime=tottime,
                           tottime_per=tottime / total_calls if total_calls > 0 else 0,
                           cumtime=cumtime,
                           cumtime_per=cumtime / primitive_calls if primitive_calls > 0 else 0,
                           function_name=function_name))

    return result
581