Completed
Push — master ( b996cf...bbb940 )
by Ionel Cristian
9s
created

_parse_hosts()   C

Complexity

Conditions 7

Size

Total Lines 22

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 7
c 2
b 0
f 0
dl 0
loc 22
rs 5.7894
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import netrc
8
import ntpath
9
import os
10
import platform
11
import re
12
import subprocess
13
import sys
14
import types
15
from datetime import datetime
16
from decimal import Decimal
17
from functools import partial
18
from os.path import basename
19
from os.path import dirname
20
from os.path import exists
21
from os.path import join
22
from os.path import split
23
24
from .compat import PY3
25
26
try:
27
    from urllib.parse import urlparse, parse_qs
28
except ImportError:
29
    from urlparse import urlparse, parse_qs
30
31
try:
    from subprocess import check_output, CalledProcessError
except ImportError:
    # The import above failed, so provide local stand-ins with the same names.
    class CalledProcessError(subprocess.CalledProcessError):
        """``subprocess.CalledProcessError`` that also stores the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command, return its stdout, raise ``CalledProcessError`` on a non-zero exit.

        Arguments are forwarded to ``subprocess.Popen``; ``stdout`` may not be
        passed because it is always replaced with a pipe.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return captured
        # Report the command the same way it was given: ``args=`` keyword first,
        # otherwise the first positional argument.
        failed_cmd = kwargs.get("args")
        if failed_cmd is None:
            failed_cmd = popenargs[0]
        raise CalledProcessError(status, failed_cmd, captured)
51
52
# Display labels keyed by the unit prefix that ``time_unit`` returns.
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)",
}
# Column names accepted by ``parse_columns``.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "rounds", "iterations"]
59
60
61
class SecondsDecimal(Decimal):
    """A ``Decimal`` number of seconds whose ``str()`` is a scaled, unit-suffixed time."""

    @property
    def as_string(self):
        """The plain decimal text, exactly as ``Decimal.__str__`` renders it."""
        return Decimal.__str__(self)

    def __float__(self):
        # Convert through the decimal text rather than Decimal's own float conversion.
        return float(Decimal.__str__(self))

    def __str__(self):
        seconds = float(Decimal.__str__(self))
        return "{0}s".format(format_time(seconds))
71
72
73
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its dotted ``module.name``."""

    def __init__(self, target):
        self.target = target

    def __str__(self):
        """Return ``<module>.<name>``, degrading gracefully when parts are missing.

        The module prefix is omitted when the target has no ``__module__`` or
        when it is ``None`` (as for bound built-in methods such as
        ``[].append``); ``repr(target)`` replaces a missing ``__name__``.
        """
        module = getattr(self.target, '__module__', None)
        name = module + "." if module is not None else ""
        name += self.target.__name__ if hasattr(self.target, '__name__') else repr(self.target)
        return name

    def __repr__(self):
        return "NameWrapper(%s)" % repr(self.target)
84
85
86
def get_tag(project_name=None):
    """Return ``<commit id>_<current time>``, with a marker appended for a dirty tree.

    ``project_name`` is passed through to ``get_commit_info``.
    """
    commit = get_commit_info(project_name)
    tag = "_".join([commit['id'], get_current_time()])
    if commit['dirty']:
        tag = "_".join([tag, "uncommited-changes"])
    return tag
92
93
94
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture bits>`` for this interpreter."""
    system = platform.system()
    implementation = platform.python_implementation()
    major_minor = ".".join(platform.python_version_tuple()[:2])
    bits = platform.architecture()[0]
    return "%s-%s-%s-%s" % (system, implementation, major_minor, bits)
101
102
103
class Fallback(object):
    """Callable that tries a list of registered functions until one succeeds.

    ``exceptions`` are the exception types that mean "this implementation does
    not apply, try the next one"; any other exception propagates to the caller.
    """

    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        """Return the result of the first registered function that does not fail.

        Raises ``NotImplementedError`` when every function raised one of the
        tolerated exceptions, or when nothing has been registered at all.
        """
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        if not self.functions:
            # Without this guard the message below would index an empty list
            # and leak an IndexError instead of the documented error.
            raise NotImplementedError("No fallback implementations have been registered.")
        raise NotImplementedError("Must have an fallback implementation for %s" % self.functions[0])

    def register(self, other):
        """Append ``other`` to the candidates and return this ``Fallback``.

        Because the ``Fallback`` itself is returned, using this method as a
        decorator rebinds the decorated name to the ``Fallback`` object.
        """
        self.functions.append(other)
        return self
120
121
122
# Dispatcher for the project name: implementations are added with
# ``@get_project_name.register`` and tried in registration order; one that
# raises IndexError or CalledProcessError is skipped in favour of the next.
get_project_name = Fallback(IndexError, CalledProcessError)
123
124
125
@get_project_name.register
def get_project_name_git():
    """Derive the project name from the last component of git's ``remote.origin.url``.

    Only a trailing ``.git`` suffix is removed, so names that merely contain
    ``.git`` (for example ``user.github.io``) are kept intact. An ``IndexError``
    is raised when the URL has no usable component.
    """
    project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if isinstance(project_address, bytes) and str != bytes:
        project_address = project_address.decode()
    parts = [part for part in re.split(r'[/:\s]', project_address) if part]
    if parts and parts[-1] == '.git':
        # A bare ".git" component (e.g. "/path/to/repo/.git") names nothing
        # itself; the project is the component before it.
        parts.pop()
    project_name = parts[-1]
    if project_name.endswith('.git'):
        project_name = project_name[:-len('.git')]
    return project_name.strip()
132
133
134
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from the last component of ``hg path default``.

    Surrounding whitespace and trailing slashes are dropped before splitting,
    so a URL such as ``https://host/user/repo/`` yields ``repo`` rather than an
    empty name.
    """
    project_address = check_output(['hg', 'path', 'default'])
    project_address = project_address.decode()
    project_name = project_address.strip().rstrip("/").split("/")[-1]
    return project_name.strip()
140
141
142
@get_project_name.register
def get_project_name_default():
    """Use the final component of the current working directory as the project name."""
    working_dir = os.getcwd()
    return basename(working_dir)
145
146
147
def in_any_parent(name, path=None):
    """Tell whether ``name`` exists in ``path`` or in any directory above it.

    ``path`` defaults to the current working directory when empty or ``None``.
    """
    current = path or os.getcwd()
    previous = None
    while True:
        found = exists(join(current, name))
        # Stop on a hit, on an empty path, or once ``dirname`` no longer changes
        # the path.
        if found or not current or current == previous:
            return found
        previous, current = current, dirname(current)
155
156
157
def get_branch_info():
    """Return the current git/hg branch name, or a parenthesised placeholder.

    Placeholders are ``(detached head)``, ``(unknown vcs)`` and
    ``(error: ...)``. The error form covers both a failing VCS command and a
    VCS executable that cannot be started, so this function does not raise for
    those cases.
    """
    def cmd(s):
        args = s.split()
        return check_output(args, stderr=subprocess.STDOUT, universal_newlines=True)

    try:
        if in_any_parent('.git'):
            branch = cmd('git rev-parse --abbrev-ref HEAD').strip()
            if branch == 'HEAD':
                return '(detached head)'
            return branch
        elif in_any_parent('.hg'):
            return cmd('hg branch').strip()
        else:
            return '(unknown vcs)'
    except subprocess.CalledProcessError as e:
        return '(error: %s)' % e.output.strip()
    except OSError as exc:
        # The executable is missing or not runnable; report it like any other
        # VCS failure instead of letting it crash callers.
        return '(error: %s)' % exc
174
175
176
def get_commit_info(project_name=None):
    """Collect commit id, commit time, dirty flag, project and branch for the checkout.

    ``project_name`` falls back to ``get_project_name()`` when empty. If anything
    goes wrong while querying git/hg, the returned dict has ``id`` set to
    ``'unknown'`` and carries the exception's ``repr`` under ``error``.
    """
    is_dirty = False
    commit_id = 'unversioned'
    committed_at = None
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    try:
        if in_any_parent('.git'):
            described = check_output('git describe --dirty --always --long --abbrev=40'.split(),
                                     universal_newlines=True).strip()
            fields = described.split('-')
            if fields[-1].strip() == 'dirty':
                is_dirty = True
                del fields[-1]
            commit_id = fields[-1].strip('g')
            committed_at = check_output('git show -s --pretty=format:"%cI"'.split(),
                                        universal_newlines=True).strip().strip('"')
        elif in_any_parent('.hg'):
            revision = check_output('hg id --id --debug'.split(), universal_newlines=True).strip()
            if revision[-1] == '+':
                is_dirty = True
            commit_id = revision.strip('+')
            committed_at = check_output('hg tip --template "{date|rfc3339date}"'.split(),
                                        universal_newlines=True).strip().strip('"')
    except Exception as exc:
        return {
            'id': 'unknown',
            'time': None,
            'dirty': is_dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }
    return {
        'id': commit_id,
        'time': committed_at,
        'dirty': is_dirty,
        'project': project_name,
        'branch': branch,
    }
216
217
218
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
220
221
222
def first_or_value(obj, value):
    """Return the sole element of ``obj``, or ``value`` when ``obj`` is empty/falsy.

    A truthy ``obj`` must contain exactly one element; otherwise the unpacking
    raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
227
228
229
def short_filename(path, machine_id=None):
    """Render ``path`` as a ``/``-joined name without the final extension.

    A leading component equal to ``machine_id`` is dropped. Objects without a
    ``parts`` attribute are simply converted with ``str``.
    """
    try:
        pieces = list(path.parts)
    except AttributeError:
        return str(path)
    # Decide about the leading component before the last one loses its extension;
    # for a single-component path they are the same element.
    drop_first = bool(pieces) and pieces[0] == machine_id
    if pieces:
        pieces[-1] = pieces[-1].rsplit('.', 1)[0]
    if drop_first:
        del pieces[0]
    return '/'.join(pieces)
244
245
246
def load_timer(string):
    """Resolve a dotted ``module.attr`` string to a ``NameWrapper`` around that attribute.

    The pseudo-module ``pep418`` maps to the stdlib ``time`` module when ``PY3``
    is true and to the package-local ``pep418`` module otherwise. A string
    without a dot raises ``argparse.ArgumentTypeError``.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        return NameWrapper(getattr(sys.modules[module_name], attr))
    if PY3:
        import time
        return NameWrapper(getattr(time, attr))
    from . import pep418
    return NameWrapper(getattr(pep418, attr))
261
262
263
class RegressionCheck(object):
    """Base for threshold checks on one stats field; ``compute`` is supplied by subclasses."""

    def __init__(self, field, threshold):
        self.field, self.threshold = field, threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else ``None``."""
        val = self.compute(current, compared)
        if not val > self.threshold:
            return None
        return "Field %r has failed %s: %.9f > %.9f" % (
            self.field, self.__class__.__name__, val, self.threshold
        )
274
275
276
class PercentageRegressionCheck(RegressionCheck):
    """Check expressed as the percentage by which ``current`` exceeds ``compared``."""

    def compute(self, current, compared):
        """Return the percentage change of the field, or infinity for a falsy baseline."""
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        return float("inf")
282
283
284
class DifferenceRegressionCheck(RegressionCheck):
    """Check expressed as the absolute difference between ``current`` and ``compared``."""

    def compute(self, current, compared):
        """Return ``current[field] - compared[field]``."""
        new_value = current[self.field]
        old_value = compared[self.field]
        return new_value - old_value
287
288
289
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]+)%|'
                                      r'(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?))$')):
    """Parse ``FIELD:NUMBER%`` or ``FIELD:NUMBER`` into a regression check.

    ``FIELD`` is one of min, max, mean, median, stddev or iqr. A value with a
    ``%`` suffix is an integer percentage of any magnitude (so ``100%`` and
    above are accepted) and gives a ``PercentageRegressionCheck``; a plain
    number gives a ``DifferenceRegressionCheck``. Anything else raises
    ``argparse.ArgumentTypeError``.

    ``rex`` is the compiled pattern, built once at definition time. It is
    written with raw strings so that ``\.`` is a regex escape rather than an
    invalid string escape.
    """
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
302
303
304
def parse_warmup(string):
    """Turn an on/off/auto style flag into a boolean.

    ``auto`` is true only when running on PyPy; an empty string counts as "on".
    Unrecognised values raise ``argparse.ArgumentTypeError``.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ("off", "false", "no"):
        return False
    if normalized in ("on", "true", "yes", ""):
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
314
315
316
def name_formatter_short(bench):
    """Format as ``name (ssss)`` and drop a leading ``test_``.

    ``ssss`` is the first four characters of the last path component of
    ``bench["source"]``; the suffix is omitted when the source is empty.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
323
324
325
def name_formatter_normal(bench):
    """Format as ``name (source)`` with the last ``/`` component of the source cut to 12 characters."""
    label = bench["name"]
    source = bench["source"]
    if source:
        components = source.split('/')
        shortened = components[:-1] + [components[-1][:12]]
        label = "%s (%s)" % (label, '/'.join(shortened))
    return label
332
333
334
def name_formatter_long(bench):
    """Format as ``fullname (source)``, or just ``fullname`` when there is no source."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
339
340
341
# Maps the accepted name-format keywords to the functions that implement them.
NAME_FORMATTERS = dict([
    ("short", name_formatter_short),
    ("normal", name_formatter_normal),
    ("long", name_formatter_long),
])
346
347
348
def parse_name_format(string):
    """Normalise a name-format keyword and check it against ``NAME_FORMATTERS``.

    Returns the lower-cased, stripped keyword; unknown keywords raise
    ``argparse.ArgumentTypeError``.
    """
    key = string.lower().strip()
    if key not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % key)
    return key
354
355
356
def parse_timer(string):
    """Validate a timer spec by loading it and return its string form."""
    timer = load_timer(string)
    return str(timer)
358
359
360
def parse_sort(string):
    """Normalise a sort key and reject anything that is not a supported column.

    Returns the lower-cased, stripped key; raises ``argparse.ArgumentTypeError``
    for unsupported values.
    """
    choice = string.lower().strip()
    if choice in ("min", "max", "mean", "stddev", "name", "fullname"):
        return choice
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % choice)
368
369
370
def parse_columns(string):
    """Split a comma-separated column list and validate it against ``ALLOWED_COLUMNS``.

    Returns the lower-cased, stripped names in their given order; raises
    ``argparse.ArgumentTypeError`` naming any column that is not allowed.
    """
    columns = list(map(str.strip, string.lower().split(',')))
    invalid = set(columns) - set(ALLOWED_COLUMNS)
    if not invalid:
        return columns
    msg = "Invalid column name(s): %s. " % ', '.join(invalid)
    msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
    raise argparse.ArgumentTypeError(msg)
379
380
381
def parse_rounds(string):
    """Parse a positive integer round count.

    Raises ``argparse.ArgumentTypeError`` when the text is not an integer or the
    value is below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
390
391
392
def parse_seconds(string):
    """Validate a decimal number of seconds and return its plain string form.

    Any failure while building the ``SecondsDecimal`` is reported as
    ``argparse.ArgumentTypeError``.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
397
398
399
def parse_save(string):
    """Validate a save name: it must be non-empty and free of path-unsafe characters.

    Returns the string unchanged; raises ``argparse.ArgumentTypeError`` otherwise.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = []
    for char in r"\/:*?<>|":
        if char in string:
            found.append(char)
    illegal = ''.join(found)
    if illegal:
        raise argparse.ArgumentTypeError("Must not contain any of these characters: /:*?<>|\\ (it has %r)" % illegal)
    return string
406
407
408
def _parse_hosts(storage_url, netrc_file):
    """Expand a parsed storage URL into one ``scheme://[user:secret@]netloc`` URL per host.

    ``storage_url`` is a ``urlparse`` result whose ``netloc`` may hold several
    comma-separated hosts. ``netrc_file`` is the path of a netrc file; when it
    is empty, ``None`` or not an existing file, no credentials are added. Hosts
    that already contain ``@`` keep their own credentials.
    """
    # Load credentials only when a netrc path was actually given; expanding a
    # ``None`` path would raise before the file check is reached.
    creds = None
    if netrc_file:
        path = os.path.expanduser(netrc_file)
        if os.path.isfile(path):
            creds = netrc.netrc(path)

    urls = []
    for netloc in storage_url.netloc.split(','):
        auth = ""
        if creds is not None and '@' not in netloc:
            # netrc entries are looked up by the part before the first ``:``.
            host = netloc.split(':')[0]
            res = creds.authenticators(host)
            if res:
                user, _, secret = res
                auth = "{user}:{secret}@".format(user=user, secret=secret)
        url = "{scheme}://{auth}{netloc}".format(scheme=storage_url.scheme,
                                                 netloc=netloc, auth=auth)
        urls.append(url)
    return urls
430
431
432
def parse_elasticsearch_storage(string, default_index="benchmark",
                                default_doctype="benchmark", netrc_file=''):
    """Split an elasticsearch storage URL into ``(hosts, index, doctype, project_name)``.

    The first two path segments, when present, override the default index and
    doctype. ``project_name`` comes from the ``project_name`` query parameter,
    falling back to ``get_project_name()``.
    """
    storage_url = urlparse(string)
    hosts = _parse_hosts(storage_url, netrc_file)
    index, doctype = default_index, default_doctype
    url_path = storage_url.path
    if url_path and url_path != "/":
        segments = url_path.strip("/").split("/")
        index = segments[0]
        if len(segments) >= 2:
            doctype = segments[1]
    params = parse_qs(storage_url.query)
    if "project_name" in params:
        project_name = params["project_name"][0]
    else:
        project_name = get_project_name()
    return hosts, index, doctype, project_name
449
450
451
def load_storage(storage, **kwargs):
    """Create the storage backend selected by the ``storage`` URL.

    A value without a scheme is treated as a ``file://`` path. ``file://`` gives
    a ``FileStorage`` and ``elasticsearch+http[s]://...`` an
    ``ElasticsearchStorage``; any other scheme raises
    ``argparse.ArgumentTypeError``. The ``netrc`` keyword is consumed here and is
    optional; the remaining keyword arguments go to the storage class.
    """
    if "://" not in storage:
        storage = "file://" + storage
    # Only the elasticsearch storage uses it. Defaulting to '' (the same default
    # as ``parse_elasticsearch_storage``) avoids a KeyError when callers omit it.
    netrc_file = kwargs.pop('netrc', '')
    if storage.startswith("file://"):
        from .storage.file import FileStorage
        return FileStorage(storage[len("file://"):], **kwargs)
    elif storage.startswith("elasticsearch+"):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        args = parse_elasticsearch_storage(storage[len("elasticsearch+"):],
                                           netrc_file=netrc_file)
        return ElasticsearchStorage(*args, **kwargs)
    else:
        raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                         "elasticsearch+http[s]://host1,host2/index/doctype")
467
468
469
def time_unit(value):
    """Pick a unit prefix and multiplier for a duration given in seconds.

    Returns ``(prefix, multiplier)`` where prefix is ``"n"``, ``"u"``, ``"m"``
    or ``""`` depending on which threshold ``value`` falls below.
    """
    for limit, prefix, multiplier in ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3)):
        if value < limit:
            return prefix, multiplier
    return "", 1.
478
479
480
def format_time(value):
    """Format seconds with two decimals in the unit chosen by ``time_unit``, followed by its prefix."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
483
484
485
class cached_property(object):
    """Descriptor that computes a value on first access and stores it on the instance."""

    def __init__(self, func):
        self.func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class: hand back the descriptor itself.
            return self
        result = self.func(obj)
        # Stored under the function's name in the instance ``__dict__``.
        obj.__dict__[self.func.__name__] = result
        return result
495
496
497
def funcname(f):
    """Return a callable's ``__name__``, looking through ``functools.partial``.

    Falls back to ``str(f)`` when no ``__name__`` is available.
    """
    try:
        target = f.func if isinstance(f, partial) else f
        return target.__name__
    except AttributeError:
        return str(f)
505
506
507
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the function attributes needed for cloning (for example
    built-ins, which have no ``__code__``) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    try:
        # first of all, we clone the code object
        co = f.__code__
        if hasattr(co, 'replace'):
            # Code objects that offer ``replace()`` can copy themselves; the
            # positional constructor below does not match the ``CodeType``
            # signature on interpreters that added ``co_posonlyargcount``.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        # then, we clone the function itself, using the new co2
        clone = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
    except AttributeError:
        return f
    # ``FunctionType`` takes no keyword-only defaults, so copy them explicitly;
    # otherwise the clone would demand arguments the original treats as optional.
    kwdefaults = getattr(f, '__kwdefaults__', None)
    if kwdefaults:
        clone.__kwdefaults__ = dict(kwdefaults)
    return clone
539
540
541
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items in sorted order."""
    rendered = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{%s}" % ", ".join(rendered)
543
544
545
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that substitutes a marker string for objects it cannot serialize."""

    def default(self, o):
        """Return ``UNSERIALIZABLE[<repr>]`` instead of raising for an unsupported object."""
        placeholder = "UNSERIALIZABLE[%r]" % o
        return placeholder
548
549
550
def safe_dumps(obj, **kwargs):
    """Serialize ``obj`` to JSON using ``SafeJSONEncoder``.

    Extra keyword arguments are forwarded to ``json.dumps``.
    """
    encoded = json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
    return encoded
552
553
554
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Iterate ``iterable`` while writing a progress line for every item.

    Yields ``(line, item)`` pairs. ``format_string`` is formatted with ``pos``
    (1-based), ``total``, ``value`` and any extra keyword arguments. The length
    of ``iterable`` is taken immediately, before the returned generator is
    consumed.
    """
    total = len(iterable)

    def progress_reporting_wrapper():
        for position, item in enumerate(iterable, 1):
            line = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(line, black=True, bold=True)
            yield line, item

    return progress_reporting_wrapper()
564
565
566
def report_noprogress(iterable, *args, **kwargs):
    """Yield ``("", item)`` for every item without reporting anything.

    Extra arguments are accepted and ignored.
    """
    for item in iterable:
        yield "", item
569
570
571
def slugify(name):
    """Replace path-unsafe characters and spaces in ``name`` with underscores.

    After each character is replaced, doubled underscores are collapsed once.
    """
    # Raw string: the backslash is one of the characters to replace, not the
    # start of an escape sequence.
    for c in r"\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
575
576
577
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are handled with Windows rules (``ntpath``): ``/`` is treated like
    ``\\`` and comparison is case-insensitive. Raises ``ValueError`` for an empty
    sequence, for a mix of absolute and relative paths, and for differing drives.
    """

    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    def meaningful(components):
        # Drop empty components and "." entries.
        return [c for c in components if c and c != curdir]

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        split_paths = [tail.split(sep) for _, tail in drivesplits]

        absolute_flags = set(tail[:1] == sep for _, tail in drivesplits)
        if len(absolute_flags) != 1:
            raise ValueError("Can't mix absolute and relative paths")
        isabs, = absolute_flags

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(drive for drive, _ in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The result keeps the spelling (case) of the first path.
        drive, path = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = meaningful(path.split(sep))

        split_paths = [meaningful(parts) for parts in split_paths]
        shortest = min(split_paths)
        longest = max(split_paths)
        for i, component in enumerate(shortest):
            if component != longest[i]:
                common = common[:i]
                break
        else:
            common = common[:len(shortest)]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
626
627
628
def get_cprofile_functions(stats):
    """
    Convert pstats structure to a list of dicts, one per profiled function.

    File paths that start with the parent of the current working directory are
    shown relative to it.
    """
    functions = []
    # this assumes that you run py.test from project root dir
    root = dirname(os.getcwd())

    for function_info, run_info in stats.stats.items():
        location = function_info[0]
        if location.startswith(root):
            location = location[len(root):].lstrip('/')
        label = '{0}:{1}({2})'.format(location, function_info[1], function_info[2])

        first_count = run_info[0]
        second_count = run_info[1]
        # if the function is recursive write number of 'total calls/primitive calls'
        if first_count == second_count:
            calls = str(first_count)
        else:
            calls = '{1}/{0}'.format(first_count, second_count)

        tottime = run_info[2]
        cumtime = run_info[3]
        functions.append({
            'ncalls_recursion': calls,
            'ncalls': second_count,
            'tottime': tottime,
            'tottime_per': tottime / first_count if first_count > 0 else 0,
            'cumtime': cumtime,
            'cumtime_per': cumtime / first_count if first_count > 0 else 0,
            'function_name': label,
        })

    return functions
657