Completed
Push — master ( b341e8...9d6e1e )
by Ionel Cristian
01:10
created

get_project_name_default()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 3
rs 10
1
from __future__ import division
2
from __future__ import print_function
3
4
import argparse
5
import genericpath
6
import json
7
import ntpath
8
import os
9
import platform
10
import re
11
import subprocess
12
import sys
13
import types
14
from datetime import datetime
15
from decimal import Decimal
16
from functools import partial
17
from os.path import basename
18
from os.path import dirname
19
from os.path import exists
20
from os.path import join
21
from os.path import split
22
23
from .compat import PY3
24
25
try:
26
    from urllib.parse import urlparse, parse_qs
27
except ImportError:
28
    from urlparse import urlparse, parse_qs
29
30
try:
    # Preferred: the implementations shipped with the subprocess module.
    from subprocess import check_output, CalledProcessError
except ImportError:
    # subprocess lacks check_output here, so provide local stand-ins.
    class CalledProcessError(subprocess.CalledProcessError):
        """CalledProcessError variant that also stores the captured output."""

        def __init__(self, returncode, cmd, output=None):
            super(CalledProcessError, self).__init__(returncode, cmd)
            self.output = output

    def check_output(*popenargs, **kwargs):
        """Run a command and return what it wrote to stdout.

        Arguments are forwarded to ``subprocess.Popen``; ``stdout`` may not be
        passed because it is always set to a pipe.

        Raises:
            ValueError: if a ``stdout`` keyword argument is supplied.
            CalledProcessError: if the process exits with a non-zero status.
        """
        if 'stdout' in kwargs:
            raise ValueError('stdout argument not allowed, it will be overridden.')
        proc = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
        captured = proc.communicate()[0]
        status = proc.poll()
        if not status:
            return captured
        # Report the command the same way Popen received it.
        command = kwargs.get("args")
        if command is None:
            command = popenargs[0]
        raise CalledProcessError(status, command, captured)
50
51
# Human-readable labels for time-unit prefixes; the keys are the prefixes
# returned by time_unit() ("" for plain seconds).
TIME_UNITS = {
    "": "Seconds",
    "m": "Milliseconds (ms)",
    "u": "Microseconds (us)",
    "n": "Nanoseconds (ns)",
}
# Column names that parse_columns() accepts.
ALLOWED_COLUMNS = ["min", "max", "mean", "stddev", "median", "iqr", "outliers", "rounds", "iterations"]
58
59
60
class SecondsDecimal(Decimal):
    """Decimal number of seconds whose ``str()`` is a scaled, human-readable time."""

    def __float__(self):
        # Convert via the plain decimal text, not the overridden __str__.
        plain_text = Decimal.__str__(self)
        return float(plain_text)

    def __str__(self):
        # Render through format_time() and append the "s" (seconds) suffix.
        seconds = float(Decimal.__str__(self))
        return "{0}s".format(format_time(seconds))

    @property
    def as_string(self):
        """The unformatted decimal text (what ``Decimal.__str__`` returns)."""
        return Decimal.__str__(self)
70
71
72
class NameWrapper(object):
    """Wrap an object so that ``str()`` yields its dotted ``module.name``."""

    def __init__(self, target):
        self.target = target

    def __str__(self):
        wrapped = self.target
        # Module prefix only when the object exposes one.
        prefix = ""
        if hasattr(wrapped, '__module__'):
            prefix = wrapped.__module__ + "."
        # Fall back to repr() for objects that have no __name__.
        if hasattr(wrapped, '__name__'):
            return prefix + wrapped.__name__
        return prefix + repr(wrapped)

    def __repr__(self):
        shown = repr(self.target)
        return "NameWrapper(%s)" % shown
83
84
85
def get_tag(project_name=None):
    """Build a tag of the form ``<commit id>_<current time>[_uncommited-changes]``.

    The commit data comes from ``get_commit_info(project_name)``; the last
    part is only appended when that data reports a dirty working copy.
    """
    commit = get_commit_info(project_name)
    pieces = [commit['id'], get_current_time()]
    if commit['dirty']:
        pieces += ["uncommited-changes"]
    return "_".join(pieces)
91
92
93
def get_machine_id():
    """Return ``<system>-<implementation>-<major.minor>-<architecture>`` for this interpreter."""
    fields = (
        platform.system(),
        platform.python_implementation(),
        ".".join(platform.python_version_tuple()[:2]),
        platform.architecture()[0],
    )
    return "%s-%s-%s-%s" % fields
100
101
102
class Fallback(object):
    """Callable that tries a chain of registered implementations in order.

    Calling the instance invokes each registered function with the given
    arguments and returns the first result. A function that raises one of the
    ``exceptions`` passed to the constructor is skipped; any other exception
    propagates to the caller.
    """

    def __init__(self, *exceptions):
        self.functions = []
        self.exceptions = exceptions

    def __call__(self, *args, **kwargs):
        """Return the result of the first implementation that succeeds.

        Raises:
            NotImplementedError: if nothing is registered, or if every
                registered implementation raised one of ``self.exceptions``.
        """
        for func in self.functions:
            try:
                return func(*args, **kwargs)
            except self.exceptions:
                continue
        if not self.functions:
            # Nothing registered: indexing functions[0] below would raise an
            # unrelated IndexError instead of the intended error.
            raise NotImplementedError("No implementation has been registered for this fallback chain.")
        raise NotImplementedError("Must have an fallback implementation for %s" % self.functions[0])

    def register(self, other):
        """Append ``other`` to the chain.

        Returns the Fallback itself, so a name decorated with ``register``
        ends up bound to this dispatcher rather than to ``other``.
        """
        self.functions.append(other)
        return self
119
120
121
# Project-name lookup: the implementations registered on this Fallback are
# tried in registration order, and one that raises IndexError or
# CalledProcessError is skipped in favour of the next.
get_project_name = Fallback(IndexError, CalledProcessError)
122
123
124
@get_project_name.register
def get_project_name_git():
    """Derive the project name from git's ``remote.origin.url`` setting.

    The name is the last component of the configured address, with one
    trailing ``.git`` suffix removed.

    Raises:
        CalledProcessError: if the git command exits with a non-zero status.
        IndexError: if the address has no usable component.
    """
    project_address = check_output(['git', 'config', '--local', 'remote.origin.url'])
    if isinstance(project_address, bytes) and str != bytes:
        project_address = project_address.decode()
    project_address = project_address.strip().rstrip('/')
    # Remove ".git" only as a suffix; splitting on it anywhere would mangle
    # names such as "foo.github.io".
    if project_address.endswith('.git'):
        project_address = project_address[:-len('.git')]
    project_name = [i for i in re.split(r'[/:\s]', project_address) if i][-1]
    return project_name.strip()
131
132
133
@get_project_name.register
def get_project_name_hg():
    """Derive the project name from Mercurial's ``default`` path.

    The name is the last non-empty ``/``-separated component of the path, so
    a trailing slash no longer produces an empty name.

    Raises:
        CalledProcessError: if the hg command exits with a non-zero status.
        IndexError: if the path has no usable component.
    """
    project_address = check_output(['hg', 'path', 'default'])
    project_address = project_address.decode()
    # Drop empty components so "https://host/repo/" still resolves to "repo".
    components = [part for part in project_address.strip().split("/") if part]
    return components[-1].strip()
139
140
141
@get_project_name.register
def get_project_name_default():
    """Use the final component of the current working directory as the project name."""
    working_dir = os.getcwd()
    return basename(working_dir)
144
145
146
def in_any_parent(name, path=None):
    """Report whether ``name`` exists in ``path`` or in any directory above it.

    ``path`` defaults to the current working directory. The walk stops when
    ``dirname`` no longer changes the path or the path becomes empty.
    """
    current = path or os.getcwd()
    previous = None
    while current and current != previous:
        if exists(join(current, name)):
            return True
        previous, current = current, dirname(current)
    # Final check at the point where the walk stopped.
    return exists(join(current, name))
154
155
156
def get_branch_info():
    """Return the current VCS branch name, or a parenthesised placeholder.

    Returns the git or hg branch name when a ``.git`` or ``.hg`` entry is
    found in the working directory or one of its parents. Otherwise returns
    ``'(detached head)'``, ``'(unknown vcs)'`` or ``'(error: ...)'``.
    """
    def cmd(s):
        # Run a whitespace-separated command line; stderr is merged into the
        # returned text.
        args = s.split()
        return check_output(args, stderr=subprocess.STDOUT, universal_newlines=True)

    try:
        if in_any_parent('.git'):
            branch = cmd('git rev-parse --abbrev-ref HEAD').strip()
            if branch == 'HEAD':
                return '(detached head)'
            return branch
        elif in_any_parent('.hg'):
            return cmd('hg branch').strip()
        else:
            return '(unknown vcs)'
    except subprocess.CalledProcessError as e:
        return '(error: %s)' % e.output.strip()
    except OSError as e:
        # The executable could not be started (e.g. not installed); report it
        # the same way as a failed command instead of raising.
        return '(error: %s)' % e
173
174
175
def get_commit_info(project_name=None):
    """Collect commit id, dirty flag, project name and branch for the working copy.

    ``project_name`` defaults to ``get_project_name()``. If reading the commit
    raises, the result has ``'id': 'unknown'`` plus an ``'error'`` entry
    holding the exception's repr. With neither ``.git`` nor ``.hg`` present
    the id is ``'unversioned'``.
    """
    dirty = False
    commit = 'unversioned'
    project_name = project_name or get_project_name()
    branch = get_branch_info()
    try:
        if in_any_parent('.git'):
            git_command = 'git describe --dirty --always --long --abbrev=40'.split()
            pieces = check_output(git_command, universal_newlines=True).strip().split('-')
            # A trailing "dirty" component marks uncommitted changes.
            if pieces[-1].strip() == 'dirty':
                dirty = True
                pieces.pop()
            commit = pieces[-1].strip('g')
        elif in_any_parent('.hg'):
            hg_command = 'hg id --id --debug'.split()
            ident = check_output(hg_command, universal_newlines=True).strip()
            # A trailing "+" marks uncommitted changes.
            if ident[-1] == '+':
                dirty = True
            commit = ident.strip('+')
    except Exception as exc:
        return {
            'id': 'unknown',
            'dirty': dirty,
            'error': repr(exc),
            'project': project_name,
            'branch': branch,
        }
    return {
        'id': commit,
        'dirty': dirty,
        'project': project_name,
        'branch': branch,
    }
208
209
210
def get_current_time():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMMSS``."""
    now = datetime.utcnow()
    return now.strftime("%Y%m%d_%H%M%S")
212
213
214
def first_or_value(obj, value):
    """Return the single element of ``obj``, or ``value`` when ``obj`` is falsy.

    A truthy ``obj`` must contain exactly one element; otherwise the
    unpacking raises ``ValueError``.
    """
    if not obj:
        return value
    (only,) = obj
    return only
219
220
221
def short_filename(path, machine_id=None):
    """Shorten a path-like object to a ``/``-joined display name.

    A first component equal to ``machine_id`` is dropped, and the text after
    the last ``.`` of the final component is removed. Objects without a
    ``parts`` attribute are returned as ``str(path)``.
    """
    try:
        components = path.parts
        final_index = len(components) - 1
    except AttributeError:
        return str(path)
    kept = []
    for index, component in enumerate(components):
        if index == 0 and component == machine_id:
            continue
        if index == final_index:
            component = component.rsplit('.', 1)[0]
        kept.append(component)
    return '/'.join(kept)
236
237
238
def load_timer(string):
    """Resolve a dotted ``module.attr`` string to a timer wrapped in NameWrapper.

    The module name ``pep418`` is special: when ``PY3`` is true the attribute
    is read from the ``time`` module, otherwise from the package-local
    ``pep418`` module.

    Raises:
        argparse.ArgumentTypeError: if ``string`` contains no dot.
    """
    if "." not in string:
        raise argparse.ArgumentTypeError("Value for --benchmark-timer must be in dotted form. Eg: 'module.attr'.")
    module_name, attr = string.rsplit(".", 1)
    if module_name != 'pep418':
        __import__(module_name)
        return NameWrapper(getattr(sys.modules[module_name], attr))
    if PY3:
        import time
        provider = time
    else:
        from . import pep418
        provider = pep418
    return NameWrapper(getattr(provider, attr))
253
254
255
class RegressionCheck(object):
    """Base for threshold checks on one statistics field.

    Subclasses supply ``compute(current, compared)``, which returns the value
    that is compared against ``threshold``.
    """

    def __init__(self, field, threshold):
        self.field = field
        self.threshold = threshold

    def fails(self, current, compared):
        """Return a failure message when the computed value exceeds the threshold, else None."""
        measured = self.compute(current, compared)
        if not measured > self.threshold:
            return None
        details = (self.field, self.__class__.__name__, measured, self.threshold)
        return "Field %r has failed %s: %.9f > %.9f" % details
266
267
268
class PercentageRegressionCheck(RegressionCheck):
    """Regression check on the percentage change of a field."""

    def compute(self, current, compared):
        """Return how many percent ``current`` is above ``compared`` for the field.

        A falsy compared value yields infinity instead of dividing by it.
        """
        baseline = compared[self.field]
        if baseline:
            return current[self.field] / baseline * 100 - 100
        return float("inf")
274
275
276
class DifferenceRegressionCheck(RegressionCheck):
    """Regression check on the absolute change of a field."""

    def compute(self, current, compared):
        """Return ``current`` minus ``compared`` for the checked field."""
        field = self.field
        return current[field] - compared[field]
279
280
281
def parse_compare_fail(string,
                       rex=re.compile(r'^(?P<field>min|max|mean|median|stddev|iqr):'
                                      r'((?P<percentage>[0-9]?[0-9])%|(?P<difference>[0-9]*\.?[0-9]+([eE][-+]?['
                                      r'0-9]+)?))$')):
    """Parse a ``field:threshold`` expression into a regression check.

    ``threshold`` is either a one- or two-digit percentage such as ``5%``
    (giving a PercentageRegressionCheck) or a number such as ``0.001``
    (giving a DifferenceRegressionCheck). ``rex`` is the compiled pattern
    used for parsing.

    Raises:
        argparse.ArgumentTypeError: if ``string`` does not match the pattern.
    """
    # The pattern uses raw literals: "\." in a plain literal is an invalid
    # escape sequence.
    m = rex.match(string)
    if m:
        g = m.groupdict()
        if g['percentage']:
            return PercentageRegressionCheck(g['field'], int(g['percentage']))
        elif g['difference']:
            return DifferenceRegressionCheck(g['field'], float(g['difference']))

    raise argparse.ArgumentTypeError("Could not parse value: %r." % string)
294
295
296
def parse_warmup(string):
    """Parse a warmup option value into a boolean.

    ``auto`` is true only when the implementation is PyPy; ``off``/``false``/
    ``no`` give False; ``on``/``true``/``yes`` and the empty string give True.
    Matching ignores case and surrounding whitespace.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    normalized = string.lower().strip()
    if normalized == "auto":
        return platform.python_implementation() == "PyPy"
    if normalized in ["off", "false", "no"]:
        return False
    if normalized in ["on", "true", "yes", ""]:
        return True
    raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
306
307
308
def name_formatter_short(bench):
    """Format a benchmark name compactly.

    Appends the first four characters of the last path component of
    ``bench["source"]`` (when set) and drops a leading ``test_`` prefix.
    """
    label = bench["name"]
    source = bench["source"]
    if source:
        label = "%s (%.4s)" % (label, split(source)[-1])
    return label[5:] if label.startswith("test_") else label
315
316
317
def name_formatter_normal(bench):
    """Format a benchmark name followed by its source.

    The last ``/``-separated component of the source is truncated to twelve
    characters. Without a source the plain name is returned.
    """
    label = bench["name"]
    source = bench["source"]
    if not source:
        return label
    pieces = source.split('/')
    pieces[-1] = pieces[-1][:12]
    return "%s (%s)" % (label, '/'.join(pieces))
324
325
326
def name_formatter_long(bench):
    """Format a benchmark as its full name, followed by the source when one is set."""
    if not bench["source"]:
        return bench["fullname"]
    return "%(fullname)s (%(source)s)" % bench
331
332
333
# Name formatter functions keyed by their format name.
NAME_FORMATTERS = dict(
    short=name_formatter_short,
    normal=name_formatter_normal,
    long=name_formatter_long,
)
338
339
340
def parse_name_format(string):
    """Validate a name-format option and return it lower-cased and stripped.

    Raises:
        argparse.ArgumentTypeError: if the value is not a key of NAME_FORMATTERS.
    """
    normalized = string.lower().strip()
    if normalized not in NAME_FORMATTERS:
        raise argparse.ArgumentTypeError("Could not parse value: %r." % normalized)
    return normalized
346
347
348
def parse_timer(string):
    """Validate a timer specification via ``load_timer`` and return its string form."""
    timer = load_timer(string)
    return str(timer)
350
351
352
def parse_sort(string):
    """Validate a sort option and return it lower-cased and stripped.

    Raises:
        argparse.ArgumentTypeError: if the value is not an accepted sort key.
    """
    normalized = string.lower().strip()
    if normalized in ("min", "max", "mean", "stddev", "name", "fullname"):
        return normalized
    raise argparse.ArgumentTypeError(
        "Unacceptable value: %r. "
        "Value for --benchmark-sort must be one of: 'min', 'max', 'mean', "
        "'stddev', 'name', 'fullname'." % normalized)
360
361
362
def parse_columns(string):
    """Split a comma-separated column list into lower-cased, stripped names.

    Raises:
        argparse.ArgumentTypeError: if any name is not in ALLOWED_COLUMNS.
    """
    columns = list(map(str.strip, string.lower().split(',')))
    unknown = set(columns) - set(ALLOWED_COLUMNS)
    if not unknown:
        return columns
    # Report the offending names together with the accepted ones.
    msg = "Invalid column name(s): %s. " % ', '.join(unknown)
    msg += "The only valid column names are: %s" % ', '.join(ALLOWED_COLUMNS)
    raise argparse.ArgumentTypeError(msg)
371
372
373
def parse_rounds(string):
    """Parse a rounds option into an integer that is at least 1.

    Raises:
        argparse.ArgumentTypeError: if the value is not an integer or is below 1.
    """
    try:
        rounds = int(string)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(exc)
    if rounds < 1:
        raise argparse.ArgumentTypeError("Value for --benchmark-rounds must be at least 1.")
    return rounds
382
383
384
def parse_seconds(string):
    """Validate a decimal seconds value and return its plain decimal text.

    Raises:
        argparse.ArgumentTypeError: if SecondsDecimal rejects the value.
    """
    try:
        seconds = SecondsDecimal(string)
        return seconds.as_string
    except Exception as exc:
        raise argparse.ArgumentTypeError("Invalid decimal value %r: %r" % (string, exc))
389
390
391
def parse_save(string):
    """Validate a save name and return it unchanged.

    Raises:
        argparse.ArgumentTypeError: if the name is empty or contains any of
            the characters ``\\ / : * ? < > |``.
    """
    if not string:
        raise argparse.ArgumentTypeError("Can't be empty.")
    found = [char for char in r"\/:*?<>|" if char in string]
    if found:
        raise argparse.ArgumentTypeError(
            "Must not contain any of these characters: /:*?<>|\\ (it has %r)" % ''.join(found))
    return string
398
399
400
def parse_elasticsearch_storage(string, default_index="benchmark", default_doctype="benchmark"):
    """Split a storage URL into ``(hosts, index, doctype, project_name)``.

    The network location may list several comma-separated hosts; each is
    combined with the URL scheme. The first two path segments override the
    index and doctype defaults. The project name comes from the
    ``project_name`` query parameter, or from ``get_project_name()`` when
    that parameter is absent.
    """
    parsed = urlparse(string)
    hosts = []
    for netloc in parsed.netloc.split(','):
        hosts.append("{scheme}://{netloc}".format(scheme=parsed.scheme, netloc=netloc))

    index, doctype = default_index, default_doctype
    path = parsed.path
    if path and path != "/":
        segments = path.strip("/").split("/")
        index = segments[0]
        if len(segments) >= 2:
            doctype = segments[1]

    names = parse_qs(parsed.query).get("project_name")
    if names is None:
        project_name = get_project_name()
    else:
        project_name = names[0]
    return hosts, index, doctype, project_name
417
418
419
def load_storage(storage, **kwargs):
    """Instantiate the storage backend selected by a storage URL.

    A value without ``://`` is treated as a ``file://`` path. ``file://``
    selects FileStorage and ``elasticsearch+...`` selects
    ElasticsearchStorage; ``kwargs`` are forwarded to the backend.

    Raises:
        argparse.ArgumentTypeError: for any other scheme.
    """
    file_scheme = "file://"
    es_scheme = "elasticsearch+"
    if "://" not in storage:
        storage = file_scheme + storage
    if storage.startswith(file_scheme):
        from .storage.file import FileStorage
        return FileStorage(storage[len(file_scheme):], **kwargs)
    if storage.startswith(es_scheme):
        from .storage.elasticsearch import ElasticsearchStorage
        # TODO update benchmark_autosave
        es_args = parse_elasticsearch_storage(storage[len(es_scheme):])
        return ElasticsearchStorage(*es_args, **kwargs)
    raise argparse.ArgumentTypeError("Storage must be in form of file://path or "
                                     "elasticsearch+http[s]://host1,host2/index/doctype")
432
433
434
def time_unit(value):
    """Pick a unit prefix and multiplier for a duration given in seconds.

    Returns ``(prefix, multiplier)``: ``("n", 1e9)`` below 1e-6,
    ``("u", 1e6)`` below 1e-3, ``("m", 1e3)`` below 1, else ``("", 1.0)``.
    """
    # (exclusive upper bound, prefix, multiplier), smallest bound first.
    scales = ((1e-6, "n", 1e9), (1e-3, "u", 1e6), (1, "m", 1e3))
    for upper_bound, prefix, multiplier in scales:
        if value < upper_bound:
            return prefix, multiplier
    return "", 1.
443
444
445
def format_time(value):
    """Format seconds as a two-decimal number followed by the prefix chosen by ``time_unit``."""
    prefix, multiplier = time_unit(value)
    scaled = value * multiplier
    return "{0:.2f}{1:s}".format(scaled, prefix)
448
449
450
class cached_property(object):
    """Descriptor that computes a value on first access and stores it on the instance.

    The result is written to the instance ``__dict__`` under the wrapped
    function's name.
    """

    def __init__(self, func):
        self.func = func
        self.__doc__ = getattr(func, '__doc__')

    def __get__(self, obj, cls):
        # Access through the class returns the descriptor itself.
        if obj is None:
            return self
        result = self.func(obj)
        obj.__dict__[self.func.__name__] = result
        return result
460
461
462
def funcname(f):
    """Return a callable's ``__name__``, looking through ``functools.partial``.

    Falls back to ``str(f)`` when no name is available.
    """
    target = f.func if isinstance(f, partial) else f
    try:
        return target.__name__
    except AttributeError:
        return str(f)
470
471
472
def clonefunc(f):
    """Deep clone the given function to create a new one.

    By default, the PyPy JIT specializes the assembler based on f.__code__:
    clonefunc makes sure that you will get a new function with a **different**
    __code__, so that PyPy will produce independent assembler. This is useful
    e.g. for benchmarks and microbenchmarks, so you can make sure to compare
    apples to apples.

    Use it with caution: if abused, this might easily produce an explosion of
    produced assembler.

    Objects lacking the attributes needed for cloning (for example built-in
    functions, which have no ``__code__``) are returned unchanged.

    from: https://bitbucket.org/antocuni/pypytools/src/tip/pypytools/util.py?at=default
    """

    # first of all, we clone the code object
    try:
        co = f.__code__
        if hasattr(co, 'replace'):
            # Code objects that provide replace() can copy themselves. The
            # positional CodeType constructor has a different arity on those
            # versions, so the calls below would raise TypeError there.
            co2 = co.replace()
        elif hasattr(co, 'co_kwonlyargcount'):
            # Python 3 code object without replace()
            co2 = types.CodeType(co.co_argcount, co.co_kwonlyargcount,
                                 co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        else:
            # Python 2 code object
            co2 = types.CodeType(co.co_argcount, co.co_nlocals, co.co_stacksize, co.co_flags, co.co_code,
                                 co.co_consts, co.co_names, co.co_varnames, co.co_filename, co.co_name,
                                 co.co_firstlineno, co.co_lnotab, co.co_freevars, co.co_cellvars)
        #
        # then, we clone the function itself, using the new co2
        f2 = types.FunctionType(co2, f.__globals__, f.__name__, f.__defaults__, f.__closure__)
        # FunctionType() takes no keyword-only defaults; copy them so the
        # clone accepts the same calls as the original.
        kwdefaults = getattr(f, '__kwdefaults__', None)
        if kwdefaults:
            f2.__kwdefaults__ = dict(kwdefaults)
        return f2
    except AttributeError:
        return f
504
505
506
def format_dict(obj):
    """Render a dict as ``{key: <json value>, ...}`` with items in sorted order."""
    rendered = ["%s: %s" % (key, json.dumps(val)) for key, val in sorted(obj.items())]
    return "{%s}" % ", ".join(rendered)
508
509
510
class SafeJSONEncoder(json.JSONEncoder):
    """JSON encoder that substitutes a placeholder string for unserializable objects."""

    def default(self, o):
        # Called for objects the base encoder cannot handle; embed their repr.
        placeholder = "UNSERIALIZABLE[%r]"
        return placeholder % o
513
514
515
def safe_dumps(obj, **kwargs):
    """Serialize ``obj`` to JSON using SafeJSONEncoder; ``kwargs`` go to ``json.dumps``."""
    serialized = json.dumps(obj, cls=SafeJSONEncoder, **kwargs)
    return serialized
517
518
519
def report_progress(iterable, terminal_reporter, format_string, **kwargs):
    """Iterate over ``iterable`` while writing a progress line for every item.

    Returns a generator of ``(message, item)`` pairs. Each message is
    ``format_string`` formatted with ``pos`` (1-based), ``total``, ``value``
    and ``kwargs``, and is also passed to ``terminal_reporter.rewrite``.
    ``len(iterable)`` is evaluated immediately, before iteration starts.
    """
    total = len(iterable)

    def _announce_each():
        for position, item in enumerate(iterable, 1):
            message = format_string.format(pos=position, total=total, value=item, **kwargs)
            terminal_reporter.rewrite(message, black=True, bold=True)
            yield message, item

    return _announce_each()
529
530
531
def report_noprogress(iterable, *args, **kwargs):
    """Yield ``("", item)`` for each item; the extra arguments are ignored."""
    for item in iterable:
        yield "", item
534
535
536
def slugify(name):
    """Replace backslash, ``/ : * ? < > |`` and spaces in ``name`` with underscores.

    After each replacement, double underscores are collapsed to a single one.
    """
    # The backslash is written as "\\": "\/" is an invalid escape sequence.
    for c in "\\/:*?<>| ":
        name = name.replace(c, '_').replace('__', '_')
    return name
540
541
542
def commonpath(paths):
    """Given a sequence of path names, returns the longest common sub-path.

    Paths are compared case-insensitively after converting ``/`` to a
    backslash. The result uses backslashes and keeps the spelling of
    ``paths[0]``.

    Raises:
        ValueError: if ``paths`` is empty, mixes absolute and relative paths,
            or the paths do not share one drive.
    """
    if not paths:
        raise ValueError('commonpath() arg is an empty sequence')

    if isinstance(paths[0], bytes):
        sep, altsep, curdir = b'\\', b'/', b'.'
    else:
        sep, altsep, curdir = '\\', '/', '.'

    def meaningful(components):
        # Drop empty components and "." entries.
        return [part for part in components if part and part != curdir]

    try:
        drivesplits = [ntpath.splitdrive(p.replace(altsep, sep).lower()) for p in paths]
        raw_components = [tail.split(sep) for _, tail in drivesplits]

        absolute_flags = set(tail[:1] == sep for _, tail in drivesplits)
        if len(absolute_flags) != 1:
            raise ValueError("Can't mix absolute and relative paths")
        isabs = absolute_flags.pop()

        # Check that all drive letters or UNC paths match. The check is made only
        # now otherwise type errors for mixing strings and bytes would not be
        # caught.
        if len(set(drive_part for drive_part, _ in drivesplits)) != 1:
            raise ValueError("Paths don't have the same drive")

        # The returned components come from the first path as given, not from
        # the lower-cased copies used for comparison.
        drive, first_tail = ntpath.splitdrive(paths[0].replace(altsep, sep))
        common = meaningful(first_tail.split(sep))

        candidates = [meaningful(parts) for parts in raw_components]
        lowest, highest = min(candidates), max(candidates)
        # Components shared by the smallest and largest lists are shared by all.
        keep = len(lowest)
        for index, (low_part, high_part) in enumerate(zip(lowest, highest)):
            if low_part != high_part:
                keep = index
                break
        common = common[:keep]

        prefix = drive + sep if isabs else drive
        return prefix + sep.join(common)
    except (TypeError, AttributeError):
        genericpath._check_arg_types('commonpath', *paths)
        raise
591
592
593
def get_cprofile_functions(stats):
    """
    Convert pstats structure to a list of dicts, one per profiled function.

    Each key of ``stats.stats`` supplies the file path, line number and
    function name; each value supplies the two call counts, the total time
    and the cumulative time. Paths under the parent of the current working
    directory are shown relative to that parent.
    """
    # this assumes that you run py.test from project root dir
    project_dir_parent = dirname(os.getcwd())
    prefix_length = len(project_dir_parent)

    rows = []
    for function_info, run_info in stats.stats.items():
        file_path = function_info[0]
        if file_path.startswith(project_dir_parent):
            file_path = file_path[prefix_length:].lstrip('/')
        function_name = '{0}:{1}({2})'.format(file_path, function_info[1], function_info[2])

        # if the function is recursive write number of 'total calls/primitive calls'
        primitive_calls = run_info[0]
        total_calls = run_info[1]
        if primitive_calls == total_calls:
            calls = str(primitive_calls)
        else:
            calls = '{1}/{0}'.format(primitive_calls, total_calls)

        # Per-call figures are zero when the first call count is not positive.
        counted = primitive_calls > 0
        rows.append({
            'ncalls_recursion': calls,
            'ncalls': total_calls,
            'tottime': run_info[2],
            'tottime_per': run_info[2] / primitive_calls if counted else 0,
            'cumtime': run_info[3],
            'cumtime_per': run_info[3] / primitive_calls if counted else 0,
            'function_name': function_name,
        })

    return rows
622