Completed
Push — master ( 701aa5...50a1f3 )
by Klaus
01:33
created

gather_sources_and_dependencies()   F

Complexity

Conditions 9

Size

Total Lines 37

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 1 Features 2
Metric Value
cc 9
c 3
b 1
f 2
dl 0
loc 37
rs 3
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import functools
6
import hashlib
7
import os.path
8
import re
9
import sys
10
11
import pkg_resources
12
import six
13
import sacred.optional as opt
14
from sacred.utils import is_subdir, iter_prefixes
15
16
__sacred__ = True  # marks files that should be filtered from stack traces

# One mebibyte in bytes; used as the read-chunk unit when hashing files.
MB = 1024 * 1024

# Module names that are never recorded as sources or dependencies: the
# interpreter's built-in modules plus a few stdlib modules and ``None``.
MODULE_BLACKLIST = set(sys.builtin_module_names)
MODULE_BLACKLIST.update([None, '__future__', 'hashlib', 'os', 're'])

# The type of module objects, obtained without importing ``types``.
module = type(sys)

# Accepts versions of the form [N!]N(.N)*[{a|b|c|rc}N][.postN][.devN].
PEP440_VERSION_PATTERN = re.compile(r"""
^
(\d+!)?              # epoch
(\d[\.\d]*(?<= \d))  # release
((?:[abc]|rc)\d+)?   # pre-release
(?:(\.post\d+))?     # post-release
(?:(\.dev\d+))?      # development release
$
""", re.VERBOSE)
31
32
33
def get_py_file_if_possible(pyc_name):
    """Map a compiled Python file name back to its source file if possible.

    ``foo.pyc`` / ``foo.pyo`` is mapped to ``foo.py`` when that file exists
    next to it. Every other name (``.py`` sources, extension modules such as
    ``.so``/``.pyd``, bytecode without a sibling source) is returned as-is.

    :param pyc_name: path of a module file.
    :return: path of the corresponding ``.py`` file, or ``pyc_name``.
    """
    if not pyc_name.endswith(('.pyc', '.pyo')):
        # Sources, extension modules and unknown suffixes have no derivable
        # .py counterpart. Return them unchanged rather than asserting: an
        # assert crashes on local extension modules and vanishes under -O.
        return pyc_name
    non_compiled_file = pyc_name[:-1]
    if os.path.exists(non_compiled_file):
        return non_compiled_file
    return pyc_name
41
42
43
def get_digest(filename):
    """Return the hexadecimal MD5 digest of the file at ``filename``.

    The file is read in binary mode, ``1 * MB`` bytes at a time, so large
    files are never held in memory all at once.
    """
    hasher = hashlib.md5()
    with open(filename, 'rb') as handle:
        # iter() with a sentinel keeps calling read() until it returns b''.
        for chunk in iter(lambda: handle.read(1 * MB), b''):
            hasher.update(chunk)
    return hasher.hexdigest()
51
52
53
@functools.total_ordering
class Source(object):
    """A source file of an experiment, identified by its path.

    Equality, hashing and ordering consider only ``filename``; ``digest``
    (the MD5 hex digest computed by ``get_digest`` in ``create``) is carried
    along as data.
    """

    def __init__(self, filename, digest):
        self.filename = filename
        self.digest = digest

    @staticmethod
    def create(filename):
        """Build a Source for ``filename``, preferring ``.py`` over ``.pyc``.

        :param filename: path of an existing file.
        :return: a Source with the absolute (.py-normalised) path and digest.
        :raises ValueError: if ``filename`` is empty or does not exist.
        """
        if not filename or not os.path.exists(filename):
            raise ValueError('invalid filename or file not found "{}"'
                             .format(filename))

        mainfile = get_py_file_if_possible(os.path.abspath(filename))

        return Source(mainfile, get_digest(mainfile))

    def to_tuple(self):
        """Return ``(filename, digest)``."""
        return self.filename, self.digest

    def __hash__(self):
        return hash(self.filename)

    def __eq__(self, other):
        if isinstance(other, Source):
            return self.filename == other.filename
        return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__`` and total_ordering
        # does not add it, so spell it out to keep == and != consistent.
        return not self.__eq__(other)

    def __le__(self, other):
        if not isinstance(other, Source):
            # Let Python raise a TypeError instead of an AttributeError.
            return NotImplemented
        return self.filename <= other.filename

    def __repr__(self):
        return '<Source: {}>'.format(self.filename)
86
87
88
@functools.total_ordering
class PackageDependency(object):
    """An external package the experiment depends on, identified by name.

    Equality, hashing and ordering consider only ``name``. ``version`` is
    normally a version string, or ``None`` when none could be determined.
    """

    def __init__(self, name, version):
        self.name = name
        self.version = version

    def fill_missing_version(self):
        """If ``version`` is unset, look it up via ``pkg_resources``.

        Falls back to ``'<unknown>'`` when no distribution named ``name`` is
        found. Does nothing if ``version`` is already set.
        """
        if self.version is not None:
            return
        try:
            self.version = pkg_resources.get_distribution(self.name).version
        except pkg_resources.DistributionNotFound:
            self.version = '<unknown>'

    def to_tuple(self):
        """Return ``(name, version)``."""
        return self.name, self.version

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if isinstance(other, PackageDependency):
            return self.name == other.name
        return False

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__`` and total_ordering
        # does not add it, so spell it out to keep == and != consistent.
        return not self.__eq__(other)

    def __le__(self, other):
        if not isinstance(other, PackageDependency):
            # Let Python raise a TypeError instead of an AttributeError.
            return NotImplemented
        return self.name <= other.name

    def __repr__(self):
        return '<PackageDependency: {}={}>'.format(self.name, self.version)

    @staticmethod
    def get_version_heuristic(mod):
        """Guess a module's version from common version attributes.

        Checks ``__version__``, ``VERSION`` and ``version`` in that order and
        returns the first value that is a string (or a tuple, joined with
        dots) matching ``PEP440_VERSION_PATTERN``; otherwise ``None``.
        """
        for vattr in ('__version__', 'VERSION', 'version'):
            if not hasattr(mod, vattr):
                continue
            version = getattr(mod, vattr)
            if isinstance(version, tuple):
                version = '.'.join(str(n) for n in version)
            elif not isinstance(version, six.string_types):
                continue
            if PEP440_VERSION_PATTERN.match(version):
                return version
        return None

    @staticmethod
    def create(mod):
        """Build a PackageDependency from a module object."""
        modname = mod.__name__
        version = PackageDependency.get_version_heuristic(mod)
        return PackageDependency(modname, version)
141
142
143
def create_source_or_dep(modname, mod, dependencies, sources, experiment_path):
    """Record module ``modname`` either as a local Source or as a dependency.

    Blacklisted modules and names already present in ``dependencies`` are
    skipped. A module whose file ``is_local_source`` accepts is added to
    ``sources`` (unless already there); any other loaded module is added to
    ``dependencies``, except dotted submodules without a detectable version.

    :param modname: dotted module name.
    :param mod: the module object (e.g. from ``sys.modules``) or ``None``.
    :param dependencies: set of PackageDependency, updated in place.
    :param sources: set of Source, updated in place.
    :param experiment_path: directory against which locality is judged.
    """
    if modname in MODULE_BLACKLIST:
        return
    # The set holds PackageDependency objects, which hash and compare by
    # name; a plain string never compares equal to them, so use a probe.
    if PackageDependency(modname, None) in dependencies:
        return

    # ``__file__`` may be missing, or None (e.g. namespace packages).
    mod_file = getattr(mod, '__file__', None)
    filename = os.path.abspath(mod_file) if mod_file else ''

    if filename and is_local_source(filename, modname, experiment_path):
        # Sources compare by their .py-normalised path; skip files already
        # recorded instead of hashing them again.
        probe = Source(get_py_file_if_possible(filename), None)
        if probe not in sources:
            sources.add(Source.create(filename))
    elif mod is not None:
        pdep = PackageDependency.create(mod)
        # Top-level packages are always recorded; submodules only if they
        # expose a version of their own.
        if '.' not in pdep.name or pdep.version is not None:
            dependencies.add(pdep)
159
160
161
# Credit to Trent Mick from here:
# https://www.safaribooksonline.com/library/view/python-cookbook/0596001673/ch04s16.html
def splitall(path):
    """Split ``path`` into all of its components, e.g. '/a/b' -> ['/', 'a', 'b'].

    An absolute path keeps its root as the first element; a relative path
    starts with its first directory (or file) name.
    """
    components = []
    while True:
        head, tail = os.path.split(path)
        if head == path:  # sentinel for absolute paths
            components.append(head)
            break
        if tail == path:  # sentinel for relative paths
            components.append(tail)
            break
        components.append(tail)
        path = head
    # Components were collected from the end backwards.
    components.reverse()
    return components
177
178
179
def get_relevant_path_parts(path):
    """Split ``path`` into the components that mirror a dotted module name.

    A trailing ``__init__.py`` / ``__init__.pyc`` is dropped (the package
    directory itself names the module); otherwise the file extension is
    stripped from the last component.
    """
    parts = splitall(path)
    last = parts[-1]
    if last in ('__init__.py', '__init__.pyc'):
        return parts[:-1]
    parts[-1] = os.path.splitext(last)[0]
    return parts
186
187
188
def is_local_source(filename, modname, experiment_path):
    """Decide whether ``filename`` is a local source file for ``modname``.

    The file must pass ``is_subdir(filename, experiment_path)``. It is then
    local if its path relative to ``experiment_path`` spells out ``modname``
    exactly. Otherwise, provided the relative path has no more components
    than the module name, it is local when the trailing components of its
    absolute path match the module-name components.
    """
    if not is_subdir(filename, experiment_path):
        return False
    mod_parts = modname.split('.')
    rel_parts = get_relevant_path_parts(
        os.path.relpath(filename, experiment_path))
    if rel_parts == mod_parts:
        return True
    if len(rel_parts) > len(mod_parts):
        return False
    abs_parts = get_relevant_path_parts(os.path.abspath(filename))
    # Compare from the end: the innermost path parts must equal the
    # innermost module-name parts.
    return all(p == m for p, m in zip(abs_parts[::-1], mod_parts[::-1]))
202
203
204
def _get_main_source_and_path(globs, interactive):
    """Return ``(sources, experiment_path)`` seeded from ``globs['__file__']``."""
    filename = globs.get('__file__')
    if filename is None:
        if not interactive:
            raise RuntimeError("Defining an experiment in interactive mode! "
                               "The sourcecode cannot be stored and the "
                               "experiment won't be reproducible. If you still"
                               " want to run it pass interactive=True")
        return set(), os.path.abspath(os.path.curdir)
    main = Source.create(filename)
    return {main}, os.path.dirname(main.filename)


def _get_module_path(glob):
    """Return the dotted module name ``glob`` belongs to, or None."""
    if isinstance(glob, module):
        mod_path = glob.__name__
    else:
        mod_path = getattr(glob, '__module__', None)
    # Skip empty names and non-string ``__module__`` values; neither is a
    # usable module name.
    if mod_path and isinstance(mod_path, six.string_types):
        return mod_path
    return None


def gather_sources_and_dependencies(globs, interactive=False):
    """Collect the source files and package dependencies of an experiment.

    :param globs: the globals dict of the experiment's main module.
    :param interactive: allow ``globs`` without ``__file__`` (e.g. a REPL);
        the current directory is then used as the experiment path.
    :return: ``(sources, dependencies)`` — a set of Source and a set of
        PackageDependency.
    :raises RuntimeError: if ``__file__`` is missing and ``interactive`` is
        False.
    """
    sources, experiment_path = _get_main_source_and_path(globs, interactive)
    dependencies = set()

    for glob in globs.values():
        mod_path = _get_module_path(glob)
        if mod_path is None:
            continue
        for modname in iter_prefixes(mod_path):
            create_source_or_dep(modname, sys.modules.get(modname),
                                 dependencies, sources, experiment_path)

    if opt.has_numpy:
        # Add numpy as a dependency because it might be used for randomness
        dependencies.add(PackageDependency.create(opt.np))

    return sources, dependencies
241