PackageDependency.__repr__()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 2
rs 10
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import functools
6
import hashlib
7
import os.path
8
import re
9
import sys
10
11
import pkg_resources
12
13
import sacred.optional as opt
14
from sacred import SETTINGS
15
from sacred.utils import is_subdir, iter_prefixes, basestring
16
17
MB = 1024 * 1024  # one mebibyte in bytes; used as the read chunk size for digests
18
# Module names that are never reported as sources or package dependencies.
# ``sys.builtin_module_names`` only covers modules compiled into the
# interpreter, so the remaining standard-library (and packaging) modules are
# enumerated by hand. ``None`` is included as well.
MODULE_BLACKLIST = set(sys.builtin_module_names).union([
    None, '__future__', '_abcoll', '_bootlocale', '_bsddb', '_bz2',
    '_codecs_cn', '_codecs_hk', '_codecs_iso2022', '_codecs_jp', '_codecs_kr',
    '_codecs_tw', '_collections_abc', '_compat_pickle', '_compression',
    '_crypt', '_csv', '_ctypes', '_ctypes_test', '_curses', '_curses_panel',
    '_dbm', '_decimal', '_dummy_thread', '_elementtree', '_gdbm', '_hashlib',
    '_hotshot', '_json', '_lsprof', '_LWPCookieJar', '_lzma', '_markupbase',
    '_MozillaCookieJar', '_multibytecodec', '_multiprocessing', '_opcode',
    '_osx_support', '_pydecimal', '_pyio', '_sitebuiltins', '_sqlite3',
    '_ssl', '_strptime', '_sysconfigdata', '_sysconfigdata_m',
    '_sysconfigdata_nd', '_testbuffer', '_testcapi', '_testimportmultiple',
    '_testmultiphase', '_threading_local', '_tkinter', '_weakrefset', 'abc',
    'aifc', 'antigravity', 'anydbm', 'argparse', 'ast', 'asynchat', 'asyncio',
    'asyncore', 'atexit', 'audiodev', 'audioop', 'base64', 'BaseHTTPServer',
    'Bastion', 'bdb', 'binhex', 'bisect', 'bsddb', 'bz2', 'calendar',
    'Canvas', 'CDROM', 'cgi', 'CGIHTTPServer', 'cgitb', 'chunk', 'cmath',
    'cmd', 'code', 'codecs', 'codeop', 'collections', 'colorsys', 'commands',
    'compileall', 'compiler', 'concurrent', 'ConfigParser', 'configparser',
    'contextlib', 'Cookie', 'cookielib', 'copy', 'copy_reg', 'copyreg',
    'cProfile', 'crypt', 'csv', 'ctypes', 'curses', 'datetime', 'dbhash',
    'dbm', 'decimal', 'Dialog', 'difflib', 'dircache', 'dis', 'distutils',
    'DLFCN', 'doctest', 'DocXMLRPCServer', 'dumbdbm', 'dummy_thread',
    'dummy_threading', 'easy_install', 'email', 'encodings', 'ensurepip',
    'enum', 'filecmp', 'FileDialog', 'fileinput', 'FixTk', 'fnmatch',
    'formatter', 'fpectl', 'fpformat', 'fractions', 'ftplib', 'functools',
    'future_builtins', 'genericpath', 'getopt', 'getpass', 'gettext', 'glob',
    'gzip', 'hashlib', 'heapq', 'hmac', 'hotshot', 'html', 'htmlentitydefs',
    'htmllib', 'HTMLParser', 'http', 'httplib', 'idlelib', 'ihooks',
    'imaplib', 'imghdr', 'imp', 'importlib', 'imputil', 'IN', 'inspect', 'io',
    'ipaddress', 'json', 'keyword', 'lib2to3', 'linecache', 'linuxaudiodev',
    'locale', 'logging', 'lzma', 'macpath', 'macurl2path', 'mailbox',
    'mailcap', 'markupbase', 'md5', 'mhlib', 'mimetools', 'mimetypes',
    'MimeWriter', 'mimify', 'mmap', 'modulefinder', 'multifile',
    'multiprocessing', 'mutex', 'netrc', 'new', 'nis', 'nntplib', 'ntpath',
    'nturl2path', 'numbers', 'opcode', 'operator', 'optparse', 'os',
    'os2emxpath', 'ossaudiodev', 'parser', 'pathlib', 'pdb', 'pickle',
    'pickletools', 'pip', 'pipes', 'pkg_resources', 'pkgutil', 'platform',
    'plistlib', 'popen2', 'poplib', 'posixfile', 'posixpath', 'pprint',
    'profile', 'pstats', 'pty', 'py_compile', 'pyclbr', 'pydoc', 'pydoc_data',
    'pyexpat', 'Queue', 'queue', 'quopri', 'random', 're', 'readline', 'repr',
    'reprlib', 'resource', 'rexec', 'rfc822', 'rlcompleter', 'robotparser',
    'runpy', 'sched', 'ScrolledText', 'selectors', 'sets', 'setuptools',
    'sgmllib', 'sha', 'shelve', 'shlex', 'shutil', 'signal', 'SimpleDialog',
    'SimpleHTTPServer', 'SimpleXMLRPCServer', 'site', 'sitecustomize',
    'smtpd', 'smtplib', 'sndhdr', 'socket', 'SocketServer', 'socketserver',
    'sqlite3', 'sre', 'sre_compile', 'sre_constants', 'sre_parse', 'ssl',
    'stat', 'statistics', 'statvfs', 'string', 'StringIO', 'stringold',
    'stringprep', 'struct', 'subprocess', 'sunau', 'sunaudio', 'symbol',
    'symtable', 'sysconfig', 'tabnanny', 'tarfile', 'telnetlib', 'tempfile',
    'termios', 'test', 'textwrap', 'this', 'threading', 'timeit', 'Tix',
    'tkColorChooser', 'tkCommonDialog', 'Tkconstants', 'Tkdnd',
    'tkFileDialog', 'tkFont', 'tkinter', 'Tkinter', 'tkMessageBox',
    'tkSimpleDialog', 'toaiff', 'token', 'tokenize', 'trace', 'traceback',
    'tracemalloc', 'ttk', 'tty', 'turtle', 'types', 'TYPES', 'typing',
    'unittest', 'urllib', 'urllib2', 'urlparse', 'user', 'UserDict',
    'UserList', 'UserString', 'uu', 'uuid', 'venv', 'warnings', 'wave',
    'weakref', 'webbrowser', 'wheel', 'whichdb', 'wsgiref', 'xdrlib', 'xml',
    'xmllib', 'xmlrpc', 'xmlrpclib', 'xxlimited', 'zipapp', 'zipfile',
])
78
79
module = sys.__class__  # the builtin module type, used for isinstance checks
80
# Loose matcher for PEP 440 style version strings. Groups, in order:
# epoch, release, pre-release, post-release, development release.
PEP440_VERSION_PATTERN = re.compile(
    r'^'
    r'(\d+!)?'              # epoch
    r'(\d[.\d]*(?<=\d))'    # release (must end in a digit)
    r'((?:[abc]|rc)\d+)?'   # pre-release
    r'(?:(\.post\d+))?'     # post-release
    r'(?:(\.dev\d+))?'      # development release
    r'$'
)
89
90
91
def get_py_file_if_possible(pyc_name):
    """Try to retrieve a X.py file for a given X.py[c|o] file.

    If ``pyc_name`` names a compiled ``.pyc``/``.pyo`` file and the matching
    ``.py`` file exists next to it, the ``.py`` name is returned. In every
    other case (already a ``.py`` file, no source next to the bytecode, or
    some other kind of file such as a compiled extension) ``pyc_name`` is
    returned unchanged.

    Parameters
    ----------
    pyc_name: str
        Filename of a module, usually taken from ``module.__file__``.

    Returns
    -------
    str
        The best available filename for the module's source.
    """
    # Best effort instead of ``assert``: assertions are stripped under -O,
    # so the old check only crashed in non-optimized runs.
    if pyc_name.endswith(('.pyc', '.pyo')):
        non_compiled_file = pyc_name[:-1]
        if os.path.exists(non_compiled_file):
            return non_compiled_file
    return pyc_name
100
101
102
def get_digest(filename):
    """Return the hexadecimal MD5 digest of the file at ``filename``.

    The file is read in binary mode in chunks of ``MB`` bytes, so large
    files are never loaded into memory at once.
    """
    md5 = hashlib.md5()
    with open(filename, 'rb') as stream:
        # iter() with a sentinel keeps calling read() until it returns b''
        for chunk in iter(lambda: stream.read(MB), b''):
            md5.update(chunk)
    return md5.hexdigest()
111
112
113
def get_commit_if_possible(filename):
    """Try to retrieve VCS information for a given file.

    Currently only supports git using the gitpython package.

    Parameters
    ----------
    filename : str
        Path of a file; its directory (or a parent of it) is searched for a
        git repository.

    Returns
    -------
    path : str or None
        The URL of ``repo.remote()``, or ``'git:/'`` followed by the working
        directory if that raises ValueError (e.g. no remote configured).
    commit : str or None
        The hexsha of the commit that HEAD points to.
    is_dirty : bool or None
        True if there are uncommitted changes in the repository.

    All three values are None when gitpython is unavailable, the file is not
    inside a git repository, or gitpython raises ValueError.
    """
    if not opt.has_gitpython:
        return None, None, None

    from git import Repo, InvalidGitRepositoryError
    try:
        repo = Repo(os.path.dirname(filename), search_parent_directories=True)
        try:
            path = repo.remote().url
        except ValueError:
            path = 'git:/' + repo.working_dir
        is_dirty = repo.is_dirty()
        commit = repo.head.commit.hexsha
    except (InvalidGitRepositoryError, ValueError):
        return None, None, None
    return path, commit, is_dirty
147
148
149
@functools.total_ordering
class Source(object):
    """A source file of an experiment.

    Equality, hashing and ordering use ``filename`` only. A Source also
    compares equal to (and hashes like) a plain string holding its filename,
    so ``filename in set_of_sources`` works with bare strings.
    """

    def __init__(self, filename, digest, repo, commit, isdirty):
        self.filename = filename
        self.digest = digest
        self.repo = repo
        self.commit = commit
        self.is_dirty = isdirty

    @staticmethod
    def create(filename):
        """Build a Source for ``filename``, including its digest and git info.

        The filename is made absolute and, where possible, a compiled file is
        replaced by its ``.py`` source before hashing.

        Raises
        ------
        ValueError
            If ``filename`` is empty or the file does not exist.
        """
        if not filename or not os.path.exists(filename):
            raise ValueError('invalid filename or file not found "{}"'
                             .format(filename))

        main_file = get_py_file_if_possible(os.path.abspath(filename))
        repo, commit, is_dirty = get_commit_if_possible(main_file)
        return Source(main_file, get_digest(main_file), repo, commit, is_dirty)

    def to_json(self, base_dir=None):
        """Return ``(filename, digest)``; filename relative to ``base_dir``
        if one is given."""
        if base_dir:
            return os.path.relpath(self.filename, base_dir), self.digest
        else:
            return self.filename, self.digest

    def __hash__(self):
        # Same hash as the bare filename string, matching __eq__ below.
        return hash(self.filename)

    def __eq__(self, other):
        if isinstance(other, Source):
            return self.filename == other.filename
        elif isinstance(other, basestring):
            return self.filename == other
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self == other

    def __le__(self, other):
        # Accept the same operand types as __eq__; for anything else defer
        # to Python (reflected operation / TypeError) instead of crashing
        # with an AttributeError.
        if isinstance(other, Source):
            return self.filename <= other.filename
        elif isinstance(other, basestring):
            return self.filename <= other
        return NotImplemented

    def __repr__(self):
        return '<Source: {}>'.format(self.filename)
190
191
192
@functools.total_ordering
class PackageDependency(object):
    """An installed package the experiment depends on.

    Equality, hashing and ordering use ``name`` only; ``version`` is ignored.
    """

    # Class-level cache shared by all instances, filled on the first call of
    # ``create``: maps a top-level module name to (project_name, version).
    modname_to_dist = {}

    def __init__(self, name, version):
        self.name = name
        self.version = version

    def fill_missing_version(self):
        """Look up ``self.version`` among the installed distributions if unset.

        ``version`` stays None if no installed distribution matches ``name``.
        """
        if self.version is not None:
            return
        # ``working_set.by_key`` is keyed by the lower-cased project name, so
        # a mixed-case name such as "PyYAML" has to be lowered to be found.
        dist = pkg_resources.working_set.by_key.get(self.name.lower())
        self.version = dist.version if dist else None

    def to_json(self):
        """Return a requirement-style string ``'name==version'``."""
        return '{}=={}'.format(self.name, self.version or '<unknown>')

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, PackageDependency):
            return self.name == other.name
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``.
        return not self == other

    def __le__(self, other):
        if not isinstance(other, PackageDependency):
            # defer to Python instead of crashing with an AttributeError
            return NotImplemented
        return self.name <= other.name

    def __repr__(self):
        return '<PackageDependency: {}={}>'.format(self.name, self.version)

    @staticmethod
    def get_version_heuristic(mod):
        """Guess the version of ``mod`` from common version attributes.

        ``__version__``, ``VERSION`` and ``version`` are checked in that
        order. The first value that is a string matching
        ``PEP440_VERSION_PATTERN``, or a tuple that does so once joined with
        dots, is returned; otherwise None.
        """
        possible_version_attributes = ['__version__', 'VERSION', 'version']
        for vattr in possible_version_attributes:
            if hasattr(mod, vattr):
                version = getattr(mod, vattr)
                if isinstance(version, basestring) and \
                        PEP440_VERSION_PATTERN.match(version):
                    return version
                if isinstance(version, tuple):
                    version = '.'.join([str(n) for n in version])
                    if PEP440_VERSION_PATTERN.match(version):
                        return version

        return None

    @classmethod
    def create(cls, mod):
        """Create a dependency for the module ``mod``.

        Package names do not always match module names (e.g. PyYAML provides
        ``yaml``), so on first use every installed distribution's
        ``top_level.txt`` entries are mapped to its ``(project_name,
        version)`` and cached in ``modname_to_dist``. A module without such
        an entry gets its own ``__name__`` as name and None as version.
        ``get_version_heuristic`` is currently not consulted here.

        Raises
        ------
        AttributeError
            If ``mod`` has no ``__name__``.
        """
        if not cls.modname_to_dist:
            for dist in pkg_resources.working_set:
                try:
                    # NOTE: relies on the private pkg_resources API
                    toplevel_names = dist._get_metadata('top_level.txt')
                    for tln in toplevel_names:
                        cls.modname_to_dist[
                            tln] = dist.project_name, dist.version
                except Exception:
                    # A distribution with broken metadata must not prevent
                    # the others from being indexed (but don't swallow
                    # KeyboardInterrupt/SystemExit like a bare except would).
                    pass

        name, version = cls.modname_to_dist.get(mod.__name__,
                                                (mod.__name__, None))

        return cls(name, version)
259
260
261
def splitall(path):
    """Split a path into a list of directory names (and optionally a filename).

    Parameters
    ----------
    path: str
        The path (absolute or relative).

    Returns
    -------
    allparts: list[str]
        List of directory names (and optionally a filename). For an absolute
        path the first element is the root (e.g. ``"/"``); a trailing
        separator yields a final empty string.

    Example
    -------
    "foo/bar/baz.py" => ["foo", "bar", "baz.py"]
    "/absolute/path.py" => ["/", "absolute", "path.py"]

    Notes
    -----
    Credit to Trent Mick. Taken from
    https://www.safaribooksonline.com/library/view/python-cookbook/0596001673/ch04s16.html
    """
    # Parts are collected from the end of the path backwards, then reversed
    # once (instead of repeated O(n) ``insert(0, ...)`` calls).
    allparts = []
    while True:
        head, tail = os.path.split(path)
        if head == path:  # sentinel for absolute paths
            allparts.append(head)
            break
        elif tail == path:  # sentinel for relative paths
            allparts.append(tail)
            break
        else:
            path = head
            allparts.append(tail)
    allparts.reverse()
    return allparts
297
298
299
def convert_path_to_module_parts(path):
    """Convert path to a python file into list of module names.

    ``'pkg/sub/mod.py'`` becomes ``['pkg', 'sub', 'mod']``, while a package
    initializer such as ``'pkg/__init__.py'`` (or ``.pyc``) becomes
    ``['pkg']``.
    """
    parts = splitall(path)
    last = parts.pop()
    if last not in ('__init__.py', '__init__.pyc'):
        # keep the module itself, minus its file extension
        parts.append(os.path.splitext(last)[0])
    return parts
309
310
311
def is_local_source(filename, modname, experiment_path):
    """Check if a module comes from the given experiment path.

    Check if a module, given by name and filename, is from (a subdirectory of)
    the given experiment path.
    This is used to determine if the module is a local source file, or rather
    a package dependency.

    Parameters
    ----------
    filename: str
        The absolute filename of the module in question.
        (Usually module.__file__)
    modname: str
        The full name of the module including parent namespaces.
    experiment_path: str
        The base path of the experiment.

    Returns
    -------
    bool:
        True if the module was imported locally from (a subdir of) the
        experiment_path, and False otherwise.
    """
    if not is_subdir(filename, experiment_path):
        return False

    # module path as seen from the experiment directory
    relative_parts = convert_path_to_module_parts(
        os.path.relpath(filename, experiment_path))
    name_parts = modname.split('.')

    if relative_parts == name_parts:
        return True
    if len(relative_parts) > len(name_parts):
        return False

    # Otherwise the trailing components of the absolute path must match
    # the trailing components of the module name.
    absolute_parts = convert_path_to_module_parts(os.path.abspath(filename))
    return all(path_part == name_part for path_part, name_part in
               zip(absolute_parts[::-1], name_parts[::-1]))
348
349
350
def get_main_file(globs):
    """Return ``(experiment_path, main)`` for the given module globals.

    If ``globs`` has a non-None ``'__file__'``, ``main`` is the Source for
    that file and ``experiment_path`` its directory. Otherwise ``main`` is
    None and ``experiment_path`` is the absolute current directory.
    """
    filename = globs.get('__file__')
    if filename is not None:
        main = Source.create(filename)
        return os.path.dirname(main.filename), main
    return os.path.abspath(os.path.curdir), None
360
361
362
def iterate_imported_modules(globs):
    """Yield ``(modname, mod)`` for modules referenced by the given globals.

    A global contributes its module name if it is a module itself, or the
    value of its ``__module__`` attribute otherwise. For each such name,
    every prefix produced by ``iter_prefixes`` (presumably the enclosing
    packages and the module itself) is looked up in ``sys.modules``. Each
    name is yielded at most once, and names in ``MODULE_BLACKLIST`` are
    never yielded.
    """
    seen = set(MODULE_BLACKLIST)
    for value in globs.values():
        if isinstance(value, module):
            mod_path = value.__name__
        elif hasattr(value, '__module__'):
            mod_path = value.__module__
        else:
            continue  # pragma: no cover

        if not mod_path:
            continue

        for modname in iter_prefixes(mod_path):
            if modname not in seen:
                seen.add(modname)
                mod = sys.modules.get(modname)
                if mod is not None:
                    yield modname, mod
382
383
384
def iterate_all_python_files(base_path):
    """Yield the path of every ``.py`` file below ``base_path``.

    Directories named ``__pycache__`` are not descended into. Each yielded
    path is the directory reported by ``os.walk`` joined with the file name,
    so it is absolute if and only if ``base_path`` is.
    """
    # TODO support ignored directories/files
    for dirname, subdirlist, filelist in os.walk(base_path):
        # Prune in place so os.walk skips bytecode caches entirely (an exact
        # name match, unlike a substring test on the whole path).
        subdirlist[:] = [d for d in subdirlist if d != '__pycache__']
        for filename in filelist:
            if filename.endswith('.py'):
                # dirname already starts with base_path; joining base_path
                # again would duplicate it for relative base paths.
                yield os.path.join(dirname, filename)
392
393
394
def iterate_sys_modules():
    """Yield ``(modname, mod)`` for every loaded, non-blacklisted module.

    Entries whose module is None or whose name is in ``MODULE_BLACKLIST``
    are skipped.
    """
    # Iterate over a snapshot, presumably so that imports triggered by the
    # consumer between yields cannot mutate the dict being iterated.
    snapshot = list(sys.modules.items())
    for modname, mod in snapshot:
        if modname in MODULE_BLACKLIST or mod is None:
            continue
        yield modname, mod
399
400
401
def get_sources_from_modules(module_iterator, base_path):
    """Collect the local source files of the given modules.

    Parameters
    ----------
    module_iterator: iterable of (str, module)
        ``(modname, mod)`` pairs.
    base_path: str
        The experiment directory; only modules that ``is_local_source``
        considers local to it are collected.

    Returns
    -------
    set[Source]
    """
    sources = set()
    for modname, mod in module_iterator:
        # __file__ may be missing (e.g. builtin modules) or None (e.g.
        # namespace packages on Python 3.7+); os.path.abspath(None) would
        # raise TypeError, and there is no source file to record anyway.
        filename = getattr(mod, '__file__', None)
        if not filename:
            continue

        filename = os.path.abspath(filename)
        # Comparing a string against the set works because Source compares
        # and hashes equal to its filename.
        if filename not in sources and \
                is_local_source(filename, modname, base_path):
            sources.add(Source.create(filename))
    return sources
413
414
415
def get_dependencies_from_modules(module_iterator, base_path):
    """Collect package dependencies from the given modules.

    Modules that are local sources of the experiment, private (leading
    underscore) or submodules (dotted names) are skipped. The remaining
    modules are turned into ``PackageDependency`` objects, and only those
    with a known version are kept.

    Parameters
    ----------
    module_iterator: iterable of (str, module)
        ``(modname, mod)`` pairs.
    base_path: str
        The experiment directory, used to detect local sources.

    Returns
    -------
    set[PackageDependency]
    """
    dependencies = set()
    for modname, mod in module_iterator:
        # __file__ may be missing or None (e.g. namespace packages on
        # Python 3.7+); os.path.abspath(None) would raise TypeError.
        filename = getattr(mod, '__file__', None)
        if filename and is_local_source(
                os.path.abspath(filename), modname, base_path):
            continue
        if modname.startswith('_') or '.' in modname:
            continue

        try:
            pdep = PackageDependency.create(mod)
        except AttributeError:
            continue  # presumably a module object without __name__
        if pdep.version is not None:
            dependencies.add(pdep)
    return dependencies
431
432
433
def get_sources_from_sys_modules(globs, base_path):
    """Source strategy ``'sys'``: local sources among all loaded modules.

    ``globs`` is unused; it is accepted so that every strategy shares the
    same ``(globs, base_path)`` signature.
    """
    loaded_modules = iterate_sys_modules()
    return get_sources_from_modules(loaded_modules, base_path)
435
436
437
def get_sources_from_imported_modules(globs, base_path):
    """Source strategy ``'imported'``: local sources among the modules
    referenced by ``globs``."""
    referenced_modules = iterate_imported_modules(globs)
    return get_sources_from_modules(referenced_modules, base_path)
439
440
441
def get_sources_from_local_dir(globs, base_path):
    """Source strategy ``'dir'``: every ``.py`` file below ``base_path``.

    ``globs`` is unused; it is accepted for the common strategy signature.
    """
    sources = set()
    for filename in iterate_all_python_files(base_path):
        sources.add(Source.create(filename))
    return sources
444
445
446
def get_dependencies_from_sys_modules(globs, base_path):
    """Dependency strategy ``'sys'``: packages among all loaded modules.

    ``globs`` is unused; it is accepted for the common strategy signature.
    """
    loaded_modules = iterate_sys_modules()
    return get_dependencies_from_modules(loaded_modules, base_path)
448
449
450
def get_dependencies_from_imported_modules(globs, base_path):
    """Dependency strategy ``'imported'``: packages among the modules
    referenced by ``globs``."""
    referenced_modules = iterate_imported_modules(globs)
    return get_dependencies_from_modules(referenced_modules, base_path)
453
454
455
def get_dependencies_from_pkg(globs, base_path):
    """Dependency strategy ``'pkg'``: every distribution in the working set.

    ``globs`` and ``base_path`` are unused; they are accepted for the common
    strategy signature.
    """
    # Distributions reporting version '0.0.0' are skipped: an ugly hack to
    # deal with a pkg-resource version bug.
    return {PackageDependency(dist.project_name, dist.version)
            for dist in pkg_resources.working_set
            if dist.version != '0.0.0'}
462
463
464
# Maps SETTINGS['DISCOVER_SOURCES'] to a callable (globs, base_dir) -> set.
source_discovery_strategies = dict(
    none=lambda globs, path: set(),
    imported=get_sources_from_imported_modules,
    sys=get_sources_from_sys_modules,
    dir=get_sources_from_local_dir,
)
470
471
# Maps SETTINGS['DISCOVER_DEPENDENCIES'] to a callable
# (globs, base_dir) -> set.
dependency_discovery_strategies = dict(
    none=lambda globs, path: set(),
    imported=get_dependencies_from_imported_modules,
    sys=get_dependencies_from_sys_modules,
    pkg=get_dependencies_from_pkg,
)
477
478
479
def gather_sources_and_dependencies(globs, base_dir=None):
    """Scan the given globals for modules and return them as dependencies.

    The discovery strategies are selected by ``SETTINGS['DISCOVER_SOURCES']``
    and ``SETTINGS['DISCOVER_DEPENDENCIES']``.

    Parameters
    ----------
    globs: dict
        Module globals (e.g. of the main script); ``'__file__'`` is used to
        locate the main file.
    base_dir: str, optional
        Directory used for source and dependency discovery. Defaults to the
        directory of the main file, or the current directory if there is
        none.

    Returns
    -------
    main: Source or None
        The main file, if ``globs`` names one.
    sources: set[Source]
        Discovered sources; includes ``main`` when it is not None.
    dependencies: set[PackageDependency]
        Discovered dependencies; includes numpy whenever it is available.
    """
    experiment_path, main = get_main_file(globs)
    if not base_dir:
        base_dir = experiment_path

    source_strategy = source_discovery_strategies[SETTINGS['DISCOVER_SOURCES']]
    sources = source_strategy(globs, base_dir)
    if main is not None:
        sources.add(main)

    dependency_strategy = dependency_discovery_strategies[
        SETTINGS['DISCOVER_DEPENDENCIES']]
    dependencies = dependency_strategy(globs, base_dir)

    if opt.has_numpy:
        # numpy might be used for randomness, so it is always recorded
        dependencies.add(PackageDependency.create(opt.np))

    return main, sources, dependencies
500