Completed
Pull Request — master (#95)
by
unknown
36s
created

name_formatter_trial()   A

Complexity

Conditions 2

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    # ``subprocess`` does not provide these names here: define minimal
    # local stand-ins with the same calling convention.
    class CalledProcessError(subprocess.CalledProcessError):
        """``subprocess.CalledProcessError`` that also stores the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command, return its stdout, raise ``CalledProcessError`` on non-zero exit."""
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return captured
        # Report the command the same way it was supplied: keyword first, then positional.
        failed_cmd = kwargs.get("args")
        if failed_cmd is None:
            failed_cmd = popenargs[0]
        raise CalledProcessError(status, failed_cmd, captured)
51
52
# Human-readable labels for the unit prefixes returned by ``time_unit``.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
# Column names accepted by ``parse_columns``.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "ops", "outliers", "rounds", "iterations"]
59
60
61
class SecondsDecimal(Decimal):
    """Decimal amount of seconds whose ``str()`` is a unit-scaled duration."""

    @property
    def as_string(self):
        """The plain ``Decimal`` string form, without any unit formatting."""
        return super(SecondsDecimal, self).__str__()

    def __float__(self):
        plain = super(SecondsDecimal, self).__str__()
        return float(plain)

    def __str__(self):
        seconds = float(super(SecondsDecimal, self).__str__())
        return "{0}s".format(format_time(seconds))
71
72
73
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its ``module.name`` form."""

    def __init__(self, target):
        self.target = target

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)

    def __str__(self):
        wrapped = self.target
        if hasattr(wrapped, '__module__'):
            prefix = wrapped.__module__ + "."
        else:
            prefix = ""
        # Fall back to the repr when the target carries no ``__name__``.
        if hasattr(wrapped, '__name__'):
            return prefix + wrapped.__name__
        return prefix + repr(wrapped)
84
85
86
def get_tag(project_name=None):
    """Build ``<commit id>_<timestamp>``, suffixed with ``_uncommited-changes`` when dirty."""
    commit = get_commit_info(project_name)
    pieces = [commit['id'], get_current_time()]
    suffix = ["uncommited-changes"] if commit['dirty'] else []
    return "_".join(pieces + suffix)
92
93
94
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture bits>`` for this interpreter."""
    major_minor = ".".join(platform.python_version_tuple()[:2])
    fields = (
        platform.system(),
        platform.python_implementation(),
        major_minor,
        platform.architecture()[0],
    )
    return "%s-%s-%s-%s" % fields
101
102
103
class Fallback(object):
    """Callable that tries a chain of registered implementations in order.

    Calling the instance invokes each registered function with the given
    arguments and returns the first result. A function that raises one of
    ``exceptions`` is skipped; any other exception propagates.

    Raises ``NotImplementedError`` when every function was skipped or when
    nothing has been registered.
    """

    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        if not self.functions:
            # Avoid the IndexError that ``self.functions[0]`` would raise below.
            raise NotImplementedError("No implementation has been registered on this Fallback.")
        raise NotImplementedError("Must have an fallback implementation for %s" % self.functions[0])

    def register(self, other):
        """Append ``other`` to the chain; returns the ``Fallback`` itself, not ``other``."""
        self.functions.append(other)
        return self
120
121
122
# Project-name lookup chain: the implementations registered below are tried in
# registration order; one raising IndexError, CalledProcessError or OSError is skipped.
get_project_name = Fallback(IndexError, CalledProcessError, OSError)
123
124
125
@get_project_name.register
def get_project_name_git():
    """Derive the project name from the URL of the local git ``origin`` remote."""
    remote_url = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if str != bytes and isinstance(remote_url, bytes):
        remote_url = remote_url.decode()
    # Split on separators, whitespace and ".git"; the last non-empty piece is the name.
    pieces = [piece for piece in re.split(r'[/:\s\\]|\.git', remote_url) if piece]
    return pieces[-1].strip()
132
133
134
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the last ``/`` segment of ``hg path default``."""
    raw_address = check_output(['hg', 'path', 'default'])
    segments = raw_address.decode().split("/")
    return segments[-1].strip()
140
141
142
@get_project_name.register
def get_project_name_default():
    """Use the name of the current working directory as the project name."""
    cwd = os.getcwd()
    return basename(cwd)
145
146
147
def in_any_parent(name, path=None):
    """Tell whether ``name`` exists in ``path`` (default: the cwd) or any directory above it."""
    current = path if path else os.getcwd()
    previous = None
    # Climb until the entry is found, the path is exhausted, or dirname() stops changing.
    while current and current != previous:
        if exists(join(current, name)):
            break
        previous, current = current, dirname(current)
    return exists(join(current, name))
155
156
157
def subprocess_output(cmd):
    """Run the whitespace-split command line; return its stripped stdout+stderr text."""
    argv = cmd.split()
    text = check_output(argv, stderr=subprocess.STDOUT, universal_newlines=True)
    return text.strip()
159
160
161
def get_commit_info(project_name=None):
    """Collect version-control details for the current working directory.

    Looks for a ``.git`` directory first, then ``.hg``, in the cwd or any
    parent. Returns a dict with the keys ``id``, ``time``, ``author_time``,
    ``dirty``, ``project`` and ``branch``. If anything raises while querying
    the VCS, ``id`` is ``'unknown'``, the times are ``None`` and an ``error``
    key describes the exception.
    """
    project_name = project_name or get_project_name()
    commit = 'unversioned'
    commit_time = None
    author_time = None
    dirty = False
    branch = '(unknown)'
    try:
        if in_any_parent('.git'):
            described = subprocess_output('git describe --dirty --always --long --abbrev=40').split('-')
            if described[-1].strip() == 'dirty':
                dirty = True
                described.pop()
            commit = described[-1].strip('g')
            commit_time = subprocess_output('git show -s --pretty=format:"%cI"').strip('"')
            author_time = subprocess_output('git show -s --pretty=format:"%aI"').strip('"')
            branch = subprocess_output('git rev-parse --abbrev-ref HEAD')
            if branch == 'HEAD':
                branch = '(detached head)'
        elif in_any_parent('.hg'):
            hg_id = subprocess_output('hg id --id --debug')
            # A trailing "+" is treated as the dirty marker.
            if hg_id[-1] == '+':
                dirty = True
            commit = hg_id.strip('+')
            commit_time = subprocess_output('hg tip --template "{date|rfc3339date}"').strip('"')
            branch = subprocess_output('hg branch')
        return {
            'id': commit,
            'time': commit_time,
            'author_time': author_time,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        # ``dirty`` and ``branch`` keep whatever was determined before the failure.
        if isinstance(exc, CalledProcessError):
            error = 'CalledProcessError({0.returncode}, {0.output!r})'.format(exc)
        else:
            error = repr(exc)
        return {
            'id': 'unknown',
            'time': None,
            'author_time': None,
            'dirty': dirty,
            'error': error,
            'project': project_name,
            'branch': branch,
        }
207
208
209
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
211
212
213
def first_or_value(obj, value):
    """Return the single element of ``obj``, or ``value`` when ``obj`` is empty/falsy.

    A truthy ``obj`` must unpack to exactly one element, otherwise the
    unpacking raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
218
219
220
def short_filename(path, machine_id=None):
    """Shorten a path object to ``a/b/name`` form.

    A leading component equal to ``machine_id`` is dropped and the last
    component loses its final ``.extension``. Objects without a ``parts``
    attribute are returned as ``str(path)``.
    """
    try:
        final_index = len(path.parts) - 1
    except AttributeError:
        return str(path)
    kept = []
    for index, component in enumerate(path.parts):
        if index == 0 and component == machine_id:
            continue
        if index == final_index:
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
235
236
237
def load_timer(string):
    """Resolve a dotted ``module.attr`` timer spec into a ``NameWrapper`` around the attribute.

    The module name ``pep418`` is special-cased: on Python 3 the attribute is
    taken from ``time``, otherwise from the package-local ``pep418`` module.
    Raises ``argparse.ArgumentTypeError`` when the spec contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr_name = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        return NameWrapper(getattr(sys.modules[module_name], attr_name))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr_name))
    from . import pep418
    return NameWrapper(getattr(pep418, attr_name))
252
253
254
class RegressionCheck(object):
    """Base for threshold checks; subclasses provide ``compute(current, compared)``."""

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else ``None``."""
        measured = self.compute(current, compared)
        if not measured > self.threshold:
            return None
        details = (self.field, self.__class__.__name__, measured, self.threshold)
        return "Field %r has failed %s: %.9f > %.9f" % details
265
266
267
class PercentageRegressionCheck(RegressionCheck):
    """Regression measured as the percentage change of a field relative to the compared run."""

    def compute(self, current, compared):
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        # A falsy (e.g. zero) baseline cannot be divided by: report an infinite value.
        return float("inf")
273
274
275
class DifferenceRegressionCheck(RegressionCheck):
    """Regression measured as the absolute difference of a field between two runs."""

    def compute(self, current, compared):
        field = self.field
        return current[field] - compared[field]
278
279
280
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
                                      r'0-9]+)?))$')):
    """Parse a ``FIELD:THRESHOLD`` expression into a regression check.

    ``FIELD`` is one of min, max, mean, median, stddev or iqr. ``THRESHOLD``
    is either one or two digits followed by ``%`` (returns a
    ``PercentageRegressionCheck``) or a decimal/scientific-notation number
    (returns a ``DifferenceRegressionCheck``).

    The pattern is written with raw strings so that ``\.`` is a regex escape
    rather than an invalid string escape; the compiled pattern is unchanged.

    Raises ``argparse.ArgumentTypeError`` when ``string`` does not match.
    """
    match = rex.match(string)
    if match:
        groups = match.groupdict()
        if groups['percentage']:
            return PercentageRegressionCheck(groups['field'], int(groups['percentage']))
        elif groups['difference']:
            return DifferenceRegressionCheck(groups['field'], float(groups['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
293
294
295
def parse_warmup(string):
    """Interpret a warmup option value as a boolean.

    ``auto`` is true only on PyPy; off/false/no are false; on/true/yes and
    the empty string are true. Matching ignores case and surrounding
    whitespace. Anything else raises ``argparse.ArgumentTypeError``.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ("off", "false", "no"):
        return False
    if normalized in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
305
306
307
def name_formatter_short(bench):
    """Benchmark name without a leading ``test_``, plus the first 4 chars of the source's last path part."""
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
314
315
316
def name_formatter_normal(bench):
    """Benchmark name followed by its source, whose last ``/`` segment is cut to 12 characters."""
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    segments = source.split('/')
    segments[-1] = segments[-1][:12]
    return "%s (%s)" % (label, '/'.join(segments))
323
324
325
def name_formatter_long(bench):
    """Full benchmark name, followed by the complete source in parentheses when there is one."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
330
331
332
def name_formatter_trial(bench):
    """First 4 characters of the source's last path part, or ``'????'`` without a source."""
    source = bench["source"]
    if not source:
        return '????'
    return "%.4s" % split(source)[-1]
337
338
339
# Maps each name-format keyword accepted by ``parse_name_format`` to the
# function that renders a benchmark's display name in that format.
NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
    "trial": name_formatter_trial,
}
345
346
347
def parse_name_format(string):
    """Normalize a name-format keyword; raise ``ArgumentTypeError`` if it is not in ``NAME_FORMATTERS``."""
    choice = string.lower().strip()
    if choice not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % choice)
    return choice
353
354
355
def parse_timer(string):
    """Resolve the timer spec via ``load_timer`` and return its string form."""
    timer = load_timer(string)
    return str(timer)
357
358
359
def parse_sort(string):
    """Normalize a sort key (case/whitespace-insensitive); reject anything outside the allowed set."""
    key = string.lower().strip()
    if key in ("min", "max", "mean", "stddev", "name", "fullname"):
        return key
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % key)
367
368
369
def parse_columns(string):
    """Split a comma-separated column list, lower-cased and stripped.

    Raises ``argparse.ArgumentTypeError`` naming every column that is not in
    ``ALLOWED_COLUMNS``.
    """
    columns = [str.strip(item) for item in string.lower().split(',')]
    unknown = set(columns) - set(ALLOWED_COLUMNS)
    if not unknown:
        return columns
    message = "Invalid column name(s): %s. " % ', '.join(unknown)
    message += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
    raise argparse.ArgumentTypeError(message)
378
379
380
def parse_rounds(string):
    """Parse a rounds count as an integer of at least 1; raise ``ArgumentTypeError`` otherwise."""
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
389
390
391
def parse_seconds(string):
    """Validate ``string`` as a decimal number and return its plain string form.

    Any failure while building the ``SecondsDecimal`` is reported as
    ``argparse.ArgumentTypeError``.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
396
397
398
def parse_save(string):
    """Validate a save name: it must be non-empty and free of the characters ``\\/:*?<>|``."""
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [char for char in r"\/:*?<>|" if char in string]
    if found:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(found))
    return string
405
406
407
def _parse_hosts(storage_url, netrc_file):
    """Expand a parsed multi-host URL into one ``scheme://[user:secret@]host`` URL per host.

    ``storage_url`` is a ``urlparse`` result whose ``netloc`` may list several
    comma-separated hosts. ``netrc_file`` is the path of a netrc file to take
    credentials from; a falsy value (``''`` or ``None``) or a path that is not
    a file disables the lookup. Hosts that already contain ``@`` are left as
    they are.
    """
    # Load creds from the netrc file. The path is only expanded once we know
    # one was given: expanduser() raises TypeError on None.
    creds = None
    if netrc_file:
        path = os.path.expanduser(netrc_file)
        if os.path.isfile(path):
            creds = netrc.netrc(path)

    # Add creds to urls.
    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds is not None and '@' not in netloc:
            host = netloc.split(':')[0]
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls
429
430
431
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split a storage URL into ``(hosts, index, doctype, project_name)``.

    The first and second path segments override the default index and
    doctype. The project name comes from the ``project_name`` query
    parameter, falling back to ``get_project_name()``.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index, doctype = default_index, default_doctype
    url_path = storage_url.path
    if url_path and url_path != "/":
        segments = url_path.strip("/").split("/")
        index = segments[0]
        if len(segments) > 1:
            doctype = segments[1]
    params = parse_qs(storage_url.query)
    if "project_name" in params:
        project_name = params["project_name"][0]
    else:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
448
449
450
def load_storage(storage, **kwargs):
    """Instantiate the storage backend described by ``storage``.

    A value without ``://`` is treated as a ``file://`` path. Supported forms
    are ``file://path`` and ``elasticsearch+http[s]://host1,host2/index/doctype``;
    anything else raises ``argparse.ArgumentTypeError``.

    The ``netrc`` keyword is consumed here and only used for elasticsearch
    storage; it is optional and defaults to ``''`` (no netrc lookup). All
    remaining keyword arguments are passed to the storage class.
    """
    if "://" not in storage:
        storage = "file://" + storage
    # Optional: omitting it must not fail with KeyError.
    netrc_file = kwargs.pop('netrc', '')
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")
466
467
468
def time_unit(value):
    """Pick the unit prefix and multiplier for displaying ``value`` seconds.

    Returns ``(prefix, multiplier)``: ``n`` below 1e-6, ``u`` below 1e-3,
    ``m`` below 1, and an empty prefix otherwise.
    """
    scales = ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3))
    for upper_bound, prefix, multiplier in scales:
        if value < upper_bound:
            return prefix, multiplier
    return "", 1.
477
478
479
def operations_unit(value):
    """Pick the unit prefix and multiplier for displaying an operations count.

    Returns ``("M", 1e-6)`` above one million, ``("K", 1e-3)`` above one
    thousand, and ``("", 1.)`` otherwise.
    """
    for lower_bound, prefix, multiplier in ((1e+6, "M", 1e-6), (1e+3, "K", 1e-3)):
        if value > lower_bound:
            return prefix, multiplier
    return "", 1.
485
486
487
def format_time(value):
    """Format ``value`` seconds with two decimals in the unit chosen by ``time_unit``."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
490
491
492
class cached_property(object):
    """Descriptor that computes a value once and stores it in the instance ``__dict__``.

    The value is stored under the wrapped function's name, so later lookups
    find the instance attribute and no longer reach this descriptor.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        if obj is not None:
            result = self.func(obj)
            obj.__dict__[self.func.__name__] = result
            return result
        # Accessed on the class itself: expose the descriptor.
        return self
502
503
504
def funcname(f):
    """Return the name of ``f`` (of the wrapped function for a ``partial``), else ``str(f)``."""
    try:
        target = f.func if isinstance(f, partial) else f
        return target.__name__
    except AttributeError:
        return str(f)
512
513
514
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the function attributes used here (``__code__``,
    ``__globals__``, ...) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Python 3.8+: the positional CodeType constructor gained extra
            # arguments and keeps changing; replace() with no overrides
            # builds a fresh copy whatever the signature is.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            # Python 3 before 3.8.
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            # Python 2.
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
    # FunctionType() takes no keyword-only defaults; carry them over so the
    # clone accepts the same calls as the original.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        clone.__kwdefaults__ = dict(kwdefaults)
    return clone
546
547
548
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items in sorted order."""
    rendered = []
    for key, value in sorted(obj.items()):
        rendered.append("%s: %s" % (key, json.dumps(value)))
    return "{%s}" % ", ".join(rendered)
550
551
552
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that emits a placeholder string for objects it cannot serialize."""

    def default(self, o):
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
555
556
557
def safe_dumps(obj, **kwargs):
    """``json.dumps`` using ``SafeJSONEncoder``; extra keyword arguments are passed through."""
    encoder_class = SafeJSONEncoder
    return json.dumps(obj, cls=encoder_class, **kwargs)
559
560
561
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Iterate ``iterable`` while rewriting a progress line on ``terminal_reporter``.

    Returns a generator of ``(message, item)`` pairs. ``format_string`` is
    formatted with ``pos`` (1-based), ``total``, ``value`` and any extra
    keyword arguments. ``len(iterable)`` is taken immediately, before the
    generator is consumed.
    """
    total = len(iterable)

    def _iterate():
        for position, item in enumerate(iterable, 1):
            message = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return _iterate()
571
572
573
def report_noprogress(iterable, *args, **kwargs):
    """Yield ``("", item)`` for each item; extra arguments are accepted and ignored."""
    for item in iterable:
        yield "", item
576
577
578
def slugify(name):
    """Replace each of the characters ``\\ / : * ? < > |`` and space in ``name`` with ``_``.

    After each character's replacement, one pass collapses doubled
    underscores into a single one.
    """
    # The backslash is escaped explicitly: a bare "\/" is an invalid escape sequence.
    for c in "\\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
582
583
584
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are handled with Windows-style rules: ``/`` is normalized to ``\\``,
    drives are split off with ``ntpath.splitdrive`` and components are
    compared case-insensitively. The result is built from the components of
    ``paths[0]`` in their original case.

    Raises ``ValueError`` for an empty sequence, for a mix of absolute and
    relative paths, or for paths on different drives.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    # Choose separators matching the type (bytes or text) of the first path.
    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        # (drive, rest) pairs of the normalized, lower-cased paths.
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        try:
            # Unpacking a one-element set fails if the paths disagree on being absolute.
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # Components of the first path, in their original case, for the result.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        # Only the smallest and largest component lists are compared below.
        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        # Give genericpath's argument-type check a chance to raise first;
        # otherwise re-raise the original exception.
        genericpath._check_arg_types('commonpath', *paths)
        raise
633
634
635
def get_cprofile_functions(stats):
    """
    Convert a pstats structure into a list of dicts, one per profiled function.

    Each dict holds ``ncalls_recursion``, ``ncalls``, ``tottime``,
    ``tottime_per``, ``cumtime``, ``cumtime_per`` and ``function_name``
    (formatted as ``file:line(name)``).
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())

    def _per_call(total_time, count):
        # Guard against division by zero for entries with no recorded calls.
        return total_time / count if count > 0 else 0

    result = []
    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive write number of 'total calls/primitive calls'
        if run_info[0] != run_info[1]:
            calls = '{1}/{0}'.format(run_info[0], run_info[1])
        else:
            calls = str(run_info[0])

        entry = dict(
            ncalls_recursion=calls,
            ncalls=run_info[1],
            tottime=run_info[2],
            tottime_per=_per_call(run_info[2], run_info[0]),
            cumtime=run_info[3],
            cumtime_per=_per_call(run_info[3], run_info[0]),
            function_name=function_name,
        )
        result.append(entry)

    return result
664