Completed
Pull Request — master (#95)
by
unknown
40s
created

name_formatter_trial()   A

Complexity

Conditions 2

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
try:
    # Modern interpreters ship both names in ``subprocess``; nothing to backport.
    from subprocess import check_output, CalledProcessError
except ImportError:
    # Backport for interpreters whose ``subprocess`` module lacks ``check_output``.
    class CalledProcessError(subprocess.CalledProcessError):
        """``subprocess.CalledProcessError`` that additionally stores the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """
        Run a command and return what it wrote to stdout.

        Raises ``ValueError`` when a ``stdout`` keyword is supplied and
        ``CalledProcessError`` when the process exits with a non-zero code.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = proc.communicate()[0]
        exit_code = proc.poll()
        if not exit_code:
            return captured
        # Report the command the same way it was handed to Popen.
        command = kwargs.get("args")
        if command is None:
            command = popenargs[0]
        raise CalledProcessError(exit_code, command, captured)
51
52
# Human readable labels for time units, keyed by unit prefix letter
# ("" means plain seconds).
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)",
}

# Column names accepted by ``parse_columns``.
ALLOWED_COLUMNS = [
    "min",
    "max",
    "mean",
    "stddev",
    "median",
    "iqr",
    "ops",
    "outliers",
    "rounds",
    "iterations",
]
59
60
61
class SecondsDecimal(Decimal):
    """
    ``Decimal`` whose ``str()`` is rendered through ``format_time`` with an "s" suffix.

    The unformatted decimal text stays available through ``as_string``.
    """

    @property
    def as_string(self):
        """Plain ``Decimal`` text, without unit scaling or suffix."""
        return super(SecondsDecimal, self).__str__()

    def __float__(self):
        plain_text = super(SecondsDecimal, self).__str__()
        return float(plain_text)

    def __str__(self):
        plain_text = super(SecondsDecimal, self).__str__()
        return "{0}s".format(format_time(float(plain_text)))
71
72
73
class NameWrapper(object):
    """
    Wrap an object so that ``str()`` yields its dotted ``module.name``.

    ``target`` is stored as-is; only the string conversions are customised.
    """

    def __init__(self, target):
        self.target = target

    def __str__(self):
        # ``__module__`` can exist and still be ``None`` (for instance on bound
        # builtin methods such as ``[].append``); treat that like a missing
        # attribute instead of failing on ``None + "."``.
        module = getattr(self.target, '__module__', None)
        name = "" if module is None else module + "."
        # Objects without a ``__name__`` fall back to their repr.
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
84
85
86
def get_tag(project_name=None):
    """
    Build a tag of the form ``<commit id>_<current time>``.

    ``_uncommited-changes`` is appended when ``get_commit_info`` reports a
    dirty working copy.
    """
    commit_info = get_commit_info(project_name)
    tag = "_".join([commit_info['id'], get_current_time()])
    if commit_info['dirty']:
        tag = "_".join([tag, "uncommited-changes"])
    return tag
92
93
94
def get_machine_id():
    """
    Return ``<system>-<implementation>-<major.minor>-<architecture>`` for the
    running interpreter, as reported by the ``platform`` module.
    """
    system = platform.system()
    implementation = platform.python_implementation()
    major_minor = ".".join(platform.python_version_tuple()[:2])
    bits = platform.architecture()[0]
    return "%s-%s-%s-%s" % (system, implementation, major_minor, bits)
101
102
103
class Fallback(object):
    """
    Callable that tries registered functions in order before a fallback.

    Calling the instance invokes every registered function with the given
    arguments, in registration order. The first truthy result is returned.
    A function that raises one of ``exceptions`` or returns a falsy value is
    skipped. When no registered function produces a result, ``fallback`` is
    called with the same arguments and its result is returned.
    """

    def __init__(self, fallback, exceptions):
        self.fallback = fallback
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                value = func(*args, **kwargs)
            except self.exceptions:
                # Expected failure: move on to the next candidate.
                continue
            if value:
                # Previously the value was computed and then discarded, so the
                # fallback always ran; hand the first usable result back.
                return value
        return self.fallback(*args, **kwargs)

    def register(self, other):
        """
        Append ``other`` to the candidates and return this ``Fallback``.

        Because ``self`` is returned, using this as a decorator rebinds the
        decorated name to the ``Fallback`` instance.
        """
        self.functions.append(other)
        return self
124
125
126
@partial(Fallback, exceptions=(IndexError, CalledProcessError, OSError))
def get_project_name():
    """Last-resort project name: the base name of the current working directory."""
    working_dir = os.getcwd()
    return basename(working_dir)
129
130
131
@get_project_name.register
def get_project_name_git():
    """
    Derive the project name from the ``remote.origin.url`` git setting.

    Returns ``None`` when ``git rev-parse --git-dir`` prints nothing.
    """
    git_dir = check_output(['git', 'rev-parse', '--git-dir'], stderr=subprocess.STDOUT)
    if not git_dir:
        return None
    remote_url = check_output(['git', 'config', '--local', 'remote.origin.url'])
    # Only decode where ``str`` and ``bytes`` are distinct types.
    if str != bytes and isinstance(remote_url, bytes):
        remote_url = remote_url.decode()
    # Split on path/host separators, whitespace and the ".git" suffix, then
    # keep the last non-empty piece.
    pieces = [piece for piece in re.split(r'[/:\s\\]|\.git', remote_url) if piece]
    return pieces[-1].strip()
140
141
142
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the last segment of ``hg path default``."""
    raw_address = check_output(['hg', 'path', 'default', '--quiet'])
    address = raw_address.decode()
    last_segment = address.split("/")[-1]
    return last_segment.strip()
148
149
150
def in_any_parent(name, path=None):
    """
    Tell whether ``name`` exists in ``path`` or in any of its parent directories.

    ``path`` defaults to the current working directory when empty or omitted.
    """
    current = path or os.getcwd()
    previous = None
    # Climb until the entry is found, the path is exhausted, or ``dirname``
    # stops changing the path.
    while current and current != previous:
        if exists(join(current, name)):
            break
        previous, current = current, dirname(current)
    return exists(join(current, name))
158
159
160
def subprocess_output(cmd):
    """
    Run ``cmd`` (split on whitespace) and return its stripped text output.

    stderr is merged into the returned output.
    """
    argv = cmd.split()
    text = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return text.strip()
162
163
164
def get_commit_info(project_name=None):
    """
    Collect commit information for the git or mercurial checkout around the
    current working directory.

    Returns a dict with the keys ``id``, ``time``, ``author_time``, ``dirty``,
    ``project`` and ``branch``. If anything raises while querying the version
    control tool, ``id`` is ``'unknown'``, the times are ``None`` and an extra
    ``error`` key describes the exception.
    """
    is_dirty = False
    commit_id = 'unversioned'
    committed_at = None
    authored_at = None
    project_name = project_name or get_project_name()
    branch_name = '(unknown)'
    try:
        if in_any_parent('.git'):
            described = subprocess_output('git describe --dirty --always --long --abbrev=40').split('-')
            # A trailing "dirty" component marks uncommitted changes.
            if described[-1].strip() == 'dirty':
                is_dirty = True
                del described[-1]
            commit_id = described[-1].strip('g')
            committed_at = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            authored_at = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch_name = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch_name == 'HEAD':
                branch_name = '(detached head)'
        elif in_any_parent('.hg'):
            identified = subprocess_output('hg id --id --debug')
            # A trailing "+" marks uncommitted changes.
            if identified[-1] == '+':
                is_dirty = True
            commit_id = identified.strip('+')
            committed_at = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch_name = subprocess_output('hg branch')
    except Exception as exc:
        # Best effort: report the failure in the result instead of raising.
        if isinstance(exc, CalledProcessError):
            error = 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
        else:
            error = repr(exc)
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': is_dirty,
            'error': error,
            'project': project_name,
            'branch': branch_name,
        }
    else:
        return {
            'id': commit_id,
            'time': committed_at,
            'author_time': authored_at,
            'dirty': is_dirty,
            'project': project_name,
            'branch': branch_name,
        }
210
211
212
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
214
215
216
def first_or_value(obj, value):
    """
    Return the single element of ``obj``, or ``value`` when ``obj`` is falsy.

    A truthy ``obj`` must contain exactly one element; otherwise the
    unpacking raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
221
222
223
def short_filename(path, machine_id=None):
    """
    Render ``path`` as a "/"-joined string without its final extension.

    A leading component equal to ``machine_id`` is dropped. Objects without a
    ``parts`` attribute are returned as ``str(path)``.
    """
    try:
        components = path.parts
    except AttributeError:
        return str(path)
    final_index = len(components) - 1
    kept = []
    for index, component in enumerate(components):
        if index == 0 and component == machine_id:
            continue
        if index == final_index:
            # Drop whatever follows the last dot of the last component.
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
238
239
240
def load_timer(string):
    """
    Resolve a dotted ``module.attr`` string to the attribute, wrapped in ``NameWrapper``.

    The module name ``pep418`` is special-cased: on Python 3 the attribute is
    taken from ``time``, otherwise from the package-local ``pep418`` module.
    Raises ``argparse.ArgumentTypeError`` when ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr_name = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        return NameWrapper(getattr(sys.modules[module_name], attr_name))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr_name))
    from . import pep418
    return NameWrapper(getattr(pep418, attr_name))
255
256
257
class RegressionCheck(object):
    """
    Base for threshold checks on one benchmark field.

    ``fails`` relies on a ``compute(current, compared)`` method, which this
    class does not define itself.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else ``None``."""
        measured = self.compute(current, compared)
        if not measured > self.threshold:
            return None
        details = (self.field, self.__class__.__name__, measured, self.threshold)
        return "Field %r has failed %s: %.9f > %.9f" % details
268
269
270
class PercentageRegressionCheck(RegressionCheck):
    """Regression check measuring the percentage change of ``current`` relative to ``compared``."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A falsy baseline (e.g. zero) cannot be divided by.
        return float("inf")
276
277
278
class DifferenceRegressionCheck(RegressionCheck):
    """Regression check measuring the absolute difference ``current - compared``."""

    def compute(self, current, compared):
        field = self.field
        return current[field] - compared[field]
281
282
283
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
                                      r'0-9]+)?))$')):
    """
    Parse a ``field:threshold`` expression into a regression check.

    ``field`` is one of min, max, mean, median, stddev or iqr. A threshold
    written as one or two digits followed by ``%`` yields a
    ``PercentageRegressionCheck``; a plain number (optionally with an
    exponent) yields a ``DifferenceRegressionCheck``.

    The pattern is written with raw string literals so that ``\\.`` is not an
    invalid string escape sequence; the compiled expression is unchanged.

    Raises ``argparse.ArgumentTypeError`` when ``string`` does not match.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
296
297
298
def parse_warmup(string):
    """
    Parse a warmup option value into a boolean.

    "auto" enables warmup only when running on PyPy. An empty value counts as
    "on". Unknown values raise ``argparse.ArgumentTypeError``.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ["off", "false", "no"]:
        return False
    if normalized in ["on", "true", "yes", ""]:
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
308
309
310
def name_formatter_short(bench):
    """
    Short display name: the benchmark name without a leading ``test_``.

    When a source is set, the first four characters of its last path
    component are appended in parentheses.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
317
318
319
def name_formatter_normal(bench):
    """
    Normal display name: the benchmark name, followed in parentheses by the
    source with its last "/"-separated component cut to 12 characters.
    """
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    segments = source.split('/')
    segments[-1] = segments[-1][:12]
    return "%s (%s)" % (label, '/'.join(segments))
326
327
328
def name_formatter_long(bench):
    """Long display name: the full name, plus the whole source in parentheses when set."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
333
334
335
def name_formatter_trial(bench):
    """
    Trial display name: the first four characters of the source's last path
    component, or ``'????'`` when no source is set.
    """
    source = bench["source"]
    if not source:
        return '????'
    return "%.4s" % split(source)[-1]
340
341
342
# Name formatting functions, keyed by the value accepted by ``parse_name_format``.
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
    trial=name_formatter_trial,
)
348
349
350
def parse_name_format(string):
    """
    Validate a name format option.

    Returns the lower-cased, stripped value when it is a key of
    ``NAME_FORMATTERS``; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    choice = string.lower().strip()
    if choice not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % choice)
    return choice
356
357
358
def parse_timer(string):
    """Validate a timer option by loading it, and return its string form."""
    timer = load_timer(string)
    return str(timer)
360
361
362
def parse_sort(string):
    """
    Validate a sort option.

    Returns the lower-cased, stripped value; raises
    ``argparse.ArgumentTypeError`` when it is not one of the accepted keys.
    """
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
370
371
372
def parse_columns(string):
    """
    Split a comma separated column list into lower-cased, stripped names.

    Raises ``argparse.ArgumentTypeError`` naming every entry that is not in
    ``ALLOWED_COLUMNS``.
    """
    columns = [str.strip(name) for name in string.lower().split(',')]
    unknown = set(columns) - set(ALLOWED_COLUMNS)
    if not unknown:
        return columns
    # Some requested columns are not recognised.
    message = "Invalid column name(s): %s. " % ', '.join(unknown)
    message += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
    raise argparse.ArgumentTypeError(message)
381
382
383
def parse_rounds(string):
    """
    Parse a rounds option into an integer of at least 1.

    Raises ``argparse.ArgumentTypeError`` for non-integers and for values
    below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
392
393
394
def parse_seconds(string):
    """
    Validate a decimal seconds value and return its plain decimal text.

    Any failure is re-raised as ``argparse.ArgumentTypeError``.
    """
    try:
        parsed = SecondsDecimal(string)
        return parsed.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
399
400
401
def parse_save(string):
    """
    Validate a save name: non-empty and free of the characters ``\\/:*?<>|``.

    Returns ``string`` unchanged; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    offending = [char for char in r"\/:*?<>|" if char in string]
    if offending:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(offending))
    return string
408
409
410
def _parse_hosts(storage_url, netrc_file):
411
412
    # load creds from netrc file
413
    path = os.path.expanduser(netrc_file)
414
    creds = None
415
    if netrc_file and os.path.isfile(path):
416
        creds = netrc.netrc(path)
417
418
    # add creds to urls
419
    urls = []
420
    for netloc in storage_url.netloc.split(','):
421
        auth = ""
422
        if creds and '@' not in netloc:
423
            host = netloc.split(':').pop(0)
424
            res = creds.authenticators(host)
425
            if res:
426
                user, _, secret = res
427
                auth = "{user}:{secret}@".format(user=user, secret=secret)
428
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
429
                                                 netloc=netloc, auth=auth)
430
        urls.append(url)
431
    return urls
432
433
434
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """
    Break an elasticsearch storage URL into ``(hosts, index, doctype, project_name)``.

    The first and second path segments override the default index and
    doctype. ``project_name`` comes from the ``project_name`` query parameter
    when present, otherwise from ``get_project_name()``.
    """
    parsed = urlparse(string)
    hosts = _parse_hosts(parsed, netrc_file)
    index, doctype = default_index, default_doctype
    url_path = parsed.path
    if url_path and url_path != "/":
        segments = url_path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]
    params = parse_qs(parsed.query)
    try:
        project_name = params["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
451
452
453
def load_storage(storage, **kwargs):
    """
    Instantiate the storage backend selected by the scheme of ``storage``.

    A value without ``://`` is treated as a file path. The required ``netrc``
    keyword is removed from ``kwargs`` and only forwarded to the elasticsearch
    URL parser; the remaining keywords go to the storage class. Unsupported
    schemes raise ``argparse.ArgumentTypeError``.
    """
    if "://" not in storage:
        storage = "file://" + storage
    netrc_file = kwargs.pop('netrc')  # only used by elasticsearch storage
    file_prefix = "file://"
    elastic_prefix = "elasticsearch+"
    if storage.startswith(file_prefix):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_prefix):], **kwargs)
    if storage.startswith(elastic_prefix):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len(elastic_prefix):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
469
470
471
def time_unit(value):
    """
    Pick a unit prefix and multiplier for a duration given in seconds.

    Returns ``(prefix, multiplier)``: "n" below 1e-6, "u" below 1e-3, "m"
    below 1, and "" (no scaling) otherwise.
    """
    scales = (
        (1e-6, "n", 1e9),
        (1e-3, "u", 1e6),
        (1, "m", 1e3),
    )
    for upper_bound, prefix, multiplier in scales:
        if value < upper_bound:
            return prefix, multiplier
    return "", 1.
480
481
482
def operations_unit(value):
    """
    Pick a unit prefix and multiplier for an operations count.

    Returns ``(prefix, multiplier)``: "M" above 1e6, "K" above 1e3, and ""
    (no scaling) otherwise.
    """
    if value > 1e+6:
        result = "M", 1e-6
    elif value > 1e+3:
        result = "K", 1e-3
    else:
        result = "", 1.
    return result
488
489
490
def format_time(value):
    """Format a duration in seconds with two decimals and the prefix chosen by ``time_unit``."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
493
494
495
class cached_property(object):
    """
    Descriptor that computes a value on first access and stores it on the instance.

    The result is written to the instance ``__dict__`` under the wrapped
    function's ``__name__``.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        # Accessed on the class itself: hand back the descriptor.
        if obj is None:
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
505
506
507
def funcname(f):
    """
    Return the ``__name__`` of ``f``, looking through a ``functools.partial``.

    Falls back to ``str(f)`` when no ``__name__`` is available.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
515
516
517
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the function attributes needed for cloning (for example
    ``__code__``) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: the positional ``types.CodeType`` signature differs
            # between versions (``posonlyargcount`` was added in 3.8), so let
            # the code object copy itself.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            # Python 3 before 3.8
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            # Python 2
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
    # ``FunctionType`` takes no keyword-only defaults; carry them over so the
    # clone accepts the same calls as the original.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        clone.__kwdefaults__ = dict(kwdefaults)
    return clone
549
550
551
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with its items sorted."""
    rendered = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{%s}" % ", ".join(rendered)
553
554
555
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that writes a placeholder string for values it cannot serialize."""

    def default(self, o):
        # Reached only for objects the base encoder has no encoding for.
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
558
559
560
def safe_dumps(obj, **kwargs):
    """``json.dumps`` using ``SafeJSONEncoder``; extra keywords are passed through."""
    encoder_class = SafeJSONEncoder
    return json.dumps(obj, cls=encoder_class, **kwargs)
562
563
564
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """
    Iterate over ``iterable`` while writing a progress line for each item.

    ``format_string`` is formatted with ``pos`` (1-based), ``total``,
    ``value`` and any extra keywords, then passed to
    ``terminal_reporter.rewrite``. Yields ``(progress_text, item)`` pairs.
    The length of ``iterable`` is taken immediately, before iteration starts.
    """
    total = len(iterable)

    def emit():
        for position, item in enumerate(iterable, 1):
            text = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(text, black=True, bold=True)
            yield text, item

    return emit()
574
575
576
def report_noprogress(iterable, *args, **kwargs):
    """
    Yield ``("", item)`` for each item without reporting anything.

    Extra arguments are accepted and ignored.
    """
    for item in iterable:
        yield "", item
579
580
581
def slugify(name):
    """
    Replace each of the characters ``\\/:*?<>|`` and spaces in ``name`` with "_".

    After every replacement pass double underscores are collapsed, so runs of
    replaced characters (and existing "__") end up shorter.
    """
    # Raw literal: "\/" is not a valid string escape sequence.
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
585
586
587
def commonpath(paths):
    """
    Given a sequence of path names, returns the longest common sub-path.

    Both "/" and "\\" are treated as separators and the comparison is
    case-insensitive; the result uses "\\" and the spelling of the first path.
    Raises ValueError for an empty sequence, for a mix of absolute and
    relative paths, and for paths on different drives.
    """
    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    def meaningful(components):
        # Drop empty components and "." entries.
        return [part for part in components if part and part != curdir]

    try:
        normalized = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        tail_components = [tail.split(sep) for _, tail in normalized]

        absolute_flags = set(tail[:1] == sep for _, tail in normalized)
        try:
            isabs, = absolute_flags
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(drv for drv, _ in normalized)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result is assembled from the first path in its original case.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = meaningful(path.split(sep))

        tail_components = [meaningful(parts) for parts in tail_components]
        lowest = min(tail_components)
        highest = max(tail_components)
        for i, part in enumerate(lowest):
            if part != highest[i]:
                common = common[:i]
                break
        else:
            common = common[:len(lowest)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
636
637
638
def get_cprofile_functions(stats):
    """
    Convert a pstats structure into a list of dicts, one per profiled function.

    Each dict holds ``ncalls_recursion``, ``ncalls``, ``tottime``,
    ``tottime_per``, ``cumtime``, ``cumtime_per`` and ``function_name``.
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())

    rows = []
    for func_key, timing in stats.stats.items():
        location = func_key[0]
        if location.startswith(project_dir_parent):
            location = location[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(location, func_key[1], func_key[2])

        first_count, second_count = timing[0], timing[1]
        # if the function is recursive write number of 'total calls/primitive calls'
        if first_count == second_count:
            calls = str(first_count)
        else:
            calls = '{1}/{0}'.format(first_count, second_count)

        # Per-call figures divide by the first counter; 0 when it is not positive.
        has_calls = first_count > 0
        rows.append({
            'ncalls_recursion': calls,
            'ncalls': second_count,
            'tottime': timing[2],
            'tottime_per': timing[2] / first_count if has_calls else 0,
            'cumtime': timing[3],
            'cumtime_per': timing[3] / first_count if has_calls else 0,
            'function_name': function_name,
        })

    return rows
667