Completed
Pull Request — master (#68) · by Swen · created 01:17

is_list_like()   grade A

Complexity:   1 condition
Size:         3 total lines
Duplication:  0 lines (0 %)
Importance:   1 change, 0 bugs, 0 features

Metric   Value
cc       1
c        1
b        0
f        0
dl       0
loc      3
rs       10

from __future__ import division
from __future__ import print_function

import argparse
import genericpath
import json
import ntpath
import os
import platform
import re
import subprocess
import sys
import types
try:
    from collections.abc import Iterable
except ImportError:  # Python 2 / Python < 3.3
    from collections import Iterable
from datetime import datetime
from decimal import Decimal
from functools import partial

from .compat import PY3

try:
    from urllib.parse import urlparse, parse_qs
except ImportError:
    from urlparse import urlparse, parse_qs

try:
    from subprocess import check_output
except ImportError:
    class CalledProcessError(subprocess.CalledProcessError):
        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        output, unused_err = process.communicate()
        retcode = process.poll()
        if retcode:
            cmd = kwargs.get("args")
            if cmd is None:
                cmd = popenargs[0]
            raise CalledProcessError(retcode, cmd, output)
        return output

TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)"
}
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "rounds", "iterations"]


class SecondsDecimal(Decimal):
    def __float__(self):
        return float(super(SecondsDecimal, self).__str__())

    def __str__(self):
        return "{0}s".format(format_time(float(super(SecondsDecimal, self).__str__())))

    @property
    def as_string(self):
        return super(SecondsDecimal, self).__str__()


class NameWrapper(object):
    def __init__(self, target):
        self.target = target

    def __str__(self):
        name = self.target.__module__ + "." if hasattr(self.target, '__module__') else ""
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)


def get_tag(project_name=None):
    info = get_commit_info(project_name)
    parts = [info['id'], get_current_time()]
    if info['dirty']:
        parts.append("uncommitted-changes")
    return "_".join(parts)


def get_machine_id():
    return "%s-%s-%s-%s" % (
        platform.system(),
        platform.python_implementation(),
        ".".join(platform.python_version_tuple()[:2]),
        platform.architecture()[0]
    )


def get_project_name():
    if os.path.exists('.git'):
        try:
            project_address = check_output("git config --local remote.origin.url".split())
            if isinstance(project_address, bytes) and str != bytes:
                project_address = project_address.decode()
            project_name = re.findall(r'/([^/]*)\.git', project_address)[0]
            return project_name
        except (IndexError, subprocess.CalledProcessError):
            return os.path.basename(os.getcwd())
    elif os.path.exists('.hg'):
        try:
            project_address = check_output("hg path default".split())
            project_address = project_address.decode()
            project_name = project_address.split("/")[-1]
            return project_name.strip()
        except (IndexError, subprocess.CalledProcessError):
            return os.path.basename(os.getcwd())
    else:
        return os.path.basename(os.getcwd())


def get_branch_info():
    def cmd(s):
        args = s.split()
        return check_output(args, stderr=subprocess.STDOUT, universal_newlines=True)
    try:
        if os.path.exists('.git'):
            branch = cmd('git rev-parse --abbrev-ref HEAD').strip()
            if branch == 'HEAD':
                return '(detached head)'
            return branch
        elif os.path.exists('.hg'):
            return cmd('hg branch').strip()
        else:
            return '(unknown vcs)'
    except subprocess.CalledProcessError as e:
        return '(error: %s)' % e.output.strip()


def get_commit_info(project_name=None):
    dirty = False
    commit = 'unversioned'
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    try:
        if os.path.exists('.git'):
            desc = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                universal_newlines=True).strip()
            desc = desc.split('-')
            if desc[-1].strip() == 'dirty':
                dirty = True
                desc.pop()
            commit = desc[-1].strip('g')
        elif os.path.exists('.hg'):
            desc = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            if desc[-1] == '+':
                dirty = True
            commit = desc.strip('+')
        return {
            'id': commit,
            'dirty': dirty,
            'project': project_name,
            'branch': branch,
        }
    except Exception as exc:
        return {
            'id': 'unknown',
            'dirty': dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }


def get_current_time():
    return datetime.utcnow().strftime("%Y%m%d_%H%M%S")


def first_or_value(obj, value):
    if obj:
        value, = obj

    return value


def short_filename(path, machine_id=None):
    parts = []
    try:
        last = len(path.parts) - 1
    except AttributeError:
        return str(path)
    for pos, part in enumerate(path.parts):
        if not pos and part == machine_id:
            continue
        if pos == last:
            part = part.rsplit('.', 1)[0]
            # if len(part) > 16:
            #     part = "%.13s..." % part
        parts.append(part)
    return '/'.join(parts)


def load_timer(string):
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    mod, attr = string.rsplit(".", 1)
    if mod == 'pep418':
        if PY3:
            import time
            return NameWrapper(getattr(time, attr))
        else:
            from . import pep418
            return NameWrapper(getattr(pep418, attr))
    else:
        __import__(mod)
        mod = sys.modules[mod]
        return NameWrapper(getattr(mod, attr))
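

# Illustrative usage sketch, not part of the original module: load_timer() resolves a
# dotted "module.attr" string into a NameWrapper around the callable, importing the
# module on the fly. The `_example_load_timer` helper is hypothetical.
def _example_load_timer():
    timer = load_timer("time.time")
    assert isinstance(timer, NameWrapper)
    assert str(timer) == "time.time"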


class RegressionCheck(object):
    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        val = self.compute(current, compared)
        if val > self.threshold:
            return "Field %r has failed %s: %.9f > %.9f" % (
                self.field, self.__class__.__name__, val, self.threshold
            )


class PercentageRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        val = compared[self.field]
        if not val:
            return float("inf")
        return current[self.field] / val * 100 - 100


class DifferenceRegressionCheck(RegressionCheck):
    def compute(self, current, compared):
        return current[self.field] - compared[self.field]


def parse_compare_fail(string,
                       rex=re.compile('^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
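

# Illustrative usage sketch, not part of the original module: parse_compare_fail()
# turns a compare-fail spec such as "mean:5%" or "stddev:0.001" into a check object
# whose fails() returns a message only when the current run regresses past the
# threshold. The helper name and the sample numbers are hypothetical.
def _example_parse_compare_fail():
    check = parse_compare_fail("mean:5%")
    assert isinstance(check, PercentageRegressionCheck)
    # 10% slower than the compared run: beyond the 5% threshold, so a message is returned
    assert check.fails({"mean": 0.22}, {"mean": 0.20})
    # ~1% slower: within the threshold, so nothing is returned
    assert check.fails({"mean": 0.202}, {"mean": 0.20}) is None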


def parse_warmup(string):
    string = string.lower().strip()
    if string == "auto":
        return platform.python_implementation() == "PyPy"
    elif string in ["off", "false", "no"]:
        return False
    elif string in ["on", "true", "yes", ""]:
        return True
    else:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def name_formatter_short(bench):
    name = bench["name"]
    if bench["source"]:
        name = "%s (%.4s)" % (name, os.path.split(bench["source"])[-1])
    if name.startswith("test_"):
        name = name[5:]
    return name


def name_formatter_normal(bench):
    name = bench["name"]
    if bench["source"]:
        parts = bench["source"].split('/')
        parts[-1] = parts[-1][:12]
        name = "%s (%s)" % (name, '/'.join(parts))
    return name


def name_formatter_long(bench):
    if bench["source"]:
        return "%(fullname)s (%(source)s)" % bench
    else:
        return bench["fullname"]


NAME_FORMATTERS = {
    "short": name_formatter_short,
    "normal": name_formatter_normal,
    "long": name_formatter_long,
}


def parse_name_format(string):
    string = string.lower().strip()
    if string in NAME_FORMATTERS:
        return string
    else:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % string)


def parse_timer(string):
    return str(load_timer(string))


def parse_sort(string):
    string = string.lower().strip()
    if string not in ("min", "max", "mean", "stddev", "name", "fullname"):
        raise argparse.ArgumentTypeError(
            "Unacceptable value: %r. "
            "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
            "'stddev', 'name', 'fullname'." % string)
    return string


def parse_columns(string):
    columns = [str.strip(s) for s in string.lower().split(',')]
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if invalid:
        # there are extra items in columns!
        msg = "Invalid column name(s): %s. " % ', '.join(invalid)
        msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
        raise argparse.ArgumentTypeError(msg)
    return columns


def parse_rounds(string):
    try:
        value = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    else:
        if value < 1:
            raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
        return value


def parse_seconds(string):
    try:
        return SecondsDecimal(string).as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))


def parse_save(string):
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    illegal = ''.join(c for c in r"\/:*?<>|" if c in string)
    if illegal:
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string


def parse_elasticsearch_storage(string, default_index="benchmark", default_doctype="benchmark"):
    storage_url = urlparse(string)
    hosts = ["{scheme}://{netloc}".format(scheme=storage_url.scheme, netloc=netloc) for netloc in storage_url.netloc.split(',')]
    index = default_index
    doctype = default_doctype
    if storage_url.path and storage_url.path != "/":
        splitted = storage_url.path.strip("/").split("/")
        index = splitted[0]
        if len(splitted) >= 2:
            doctype = splitted[1]
    query = parse_qs(storage_url.query)
    try:
        project_name = query["project_name"][0]
    except KeyError:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
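

# Illustrative usage sketch, not part of the original module: parse_elasticsearch_storage()
# splits a comma-separated host list out of the netloc, takes index/doctype from the path
# and project_name from the query string; without a project_name parameter it falls back
# to get_project_name(). The URL and helper name below are hypothetical.
def _example_parse_elasticsearch_storage():
    hosts, index, doctype, project = parse_elasticsearch_storage(
        "http://es1:9200,es2:9200/benchmarks/runs?project_name=myproject")
    assert hosts == ["http://es1:9200", "http://es2:9200"]
    assert (index, doctype, project) == ("benchmarks", "runs", "myproject")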


def load_storage(storage, **kwargs):
    if "://" not in storage:
        storage = "file://" + storage
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        return ElasticsearchStorage(*parse_elasticsearch_storage(storage[len("elasticsearch+"):]), **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")


def time_unit(value):
    if value < 1e-6:
        return "n", 1e9
    elif value < 1e-3:
        return "u", 1e6
    elif value < 1:
        return "m", 1e3
    else:
        return "", 1.


def format_time(value):
    unit, adjustment = time_unit(value)
    return "{0:.2f}{1:s}".format(value * adjustment, unit)


class cached_property(object):
    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        if obj is None:
            return self
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value
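

# Illustrative usage sketch, not part of the original module: cached_property computes
# the wrapped method once per instance and stores the result in the instance __dict__,
# which shadows the descriptor on later lookups. The Box class and helper name are
# hypothetical.
def _example_cached_property():
    class Box(object):
        calls = 0

        @cached_property
        def value(self):
            Box.calls += 1
            return 42

    box = Box()
    assert box.value == 42
    assert box.value == 42  # second access is served from box.__dict__
    assert Box.calls == 1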


def funcname(f):
    try:
        if isinstance(f, partial):
            return f.func.__name__
        else:
            return f.__name__
    except AttributeError:
        return str(f)


def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if PY3:
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        return types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
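

# Illustrative usage sketch, not part of the original module: clonefunc() returns a
# behaviourally identical function backed by a fresh code object, so the PyPy JIT
# traces it separately. Assumes an interpreter whose CodeType signature matches the
# branches above (pre-3.8 for the PY3 branch); the helper name is hypothetical.
def _example_clonefunc():
    def add(a, b):
        return a + b

    clone = clonefunc(add)
    assert clone(2, 3) == add(2, 3) == 5
    assert clone.__code__ is not add.__code__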


def format_dict(obj):
    return "{%s}" % ", ".join("%s: %s" % (k, json.dumps(v)) for k, v in sorted(obj.items()))


class SafeJSONEncoder(json.JSONEncoder):
    def default(self, o):
        return "UNSERIALIZABLE[%r]" % o


def safe_dumps(obj, **kwargs):
    return json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
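

# Illustrative usage sketch, not part of the original module: SafeJSONEncoder never
# raises on unserializable values, so safe_dumps() produces valid JSON even when a
# payload contains arbitrary objects. The payload and helper name are hypothetical.
def _example_safe_dumps():
    payload = {"mean": 0.25, "timer": object()}
    text = safe_dumps(payload, sort_keys=True)
    assert text.startswith('{"mean": 0.25, "timer": "UNSERIALIZABLE[')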


def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    total = len(iterable)

    def progress_reporting_wrapper():
        for pos, item in enumerate(iterable):
            string = format_string.format(pos=pos + 1, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(string, black=True, bold=True)
            yield string, item
    return progress_reporting_wrapper()


def report_noprogress(iterable, *args, **kwargs):
    for pos, item in enumerate(iterable):
        yield "", item


def slugify(name):
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
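

# Illustrative usage sketch, not part of the original module: slugify() replaces path
# separators and other filesystem-hostile characters with underscores and collapses
# doubled underscores as it goes, keeping saved benchmark names usable as file names.
# The helper name is hypothetical.
def _example_slugify():
    assert slugify("tests/test_utils.py::test_foo") == "tests_test_utils.py_test_foo"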


def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path."""

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep = b'\\'
        altsep = b'/'
        curdir = b'.'
    else:
        sep = '\\'
        altsep = '/'
        curdir = '.'

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [p.split(sep) for d, p in drivesplits]

        try:
            isabs, = set(p[:1] == sep for d, p in drivesplits)
        except ValueError:
            raise ValueError("Can't mix absolute and relative paths")

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(d for d, p in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = path.split(sep)
        common = [c for c in common if c and c != curdir]

        split_paths = [[c for c in s if c and c != curdir] for s in split_paths]
        s1 = min(split_paths)
        s2 = max(split_paths)
        for i, c in enumerate(s1):
            if c != s2[i]:
                common = common[:i]
                break
        else:
            common = common[:len(s1)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
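

# Illustrative usage sketch, not part of the original module: this commonpath() works
# on ntpath-style paths, so comparison is case-insensitive, forward slashes are
# normalised to backslashes, and drive letters must match. The sample paths and
# helper name are hypothetical.
def _example_commonpath():
    paths = ["C:\\Program Files\\Python", "C:/Program Files/Common Files"]
    assert commonpath(paths) == "C:\\Program Files"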


def get_cprofile_functions(stats):
    """
    Convert a pstats structure into a list of dicts describing each profiled function.
    """
    result = []
    # this assumes that you run py.test from project root dir
    project_dir_parent = os.path.dirname(os.getcwd())

    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[len(project_dir_parent):].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive write number of 'total calls/primitive calls'
        if run_info[0] == run_info[1]:
            calls = str(run_info[0])
        else:
            calls = '{1}/{0}'.format(run_info[0], run_info[1])

        result.append(dict(ncalls_recursion=calls,
                           ncalls=run_info[1],
                           tottime=run_info[2],
                           tottime_per=run_info[2] / run_info[0] if run_info[0] > 0 else 0,
                           cumtime=run_info[3],
                           cumtime_per=run_info[3] / run_info[0] if run_info[0] > 0 else 0,
                           function_name=function_name))

    return result
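

# Illustrative usage sketch, not part of the original module: get_cprofile_functions()
# flattens raw pstats entries into one dict per profiled function, with a
# 'total calls/primitive calls' string for recursive functions. A minimal run against
# a throwaway profile; the helper name is hypothetical.
def _example_get_cprofile_functions():
    import cProfile
    import pstats

    profiler = cProfile.Profile()
    profiler.enable()
    sum(range(1000))
    profiler.disable()

    functions = get_cprofile_functions(pstats.Stats(profiler))
    assert functions
    assert all({"function_name", "ncalls", "tottime", "cumtime"} <= set(f) for f in functions)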


def is_list_like(value):
    """Return whether value is an iterable, but not a mapping or a string."""
    # strings, bytes and dicts are iterable but should not be treated as list-like
    return isinstance(value, Iterable) and not isinstance(value, (str, bytes, dict))
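

# Illustrative usage sketch, not part of the original module: containers and other
# iterables count as list-like, while strings, bytes and mappings do not. The helper
# name is hypothetical.
def _example_is_list_like():
    assert is_list_like([1, 2, 3])
    assert is_list_like((1, 2))
    assert not is_list_like("abc")
    assert not is_list_like({"a": 1})
    assert not is_list_like(42)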