Completed
Push — master ( 4e58e9...669a3f )
by Klaus
01:09
created

MongoDbOption.get_priority()   A

Complexity

Conditions 2

Size

Total Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
c 0
b 0
f 0
dl 0
loc 6
rs 9.4285
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import pickle
6
import re
7
import os.path
8
import sys
9
import time
10
11
import bson
12
import gridfs
13
import pymongo
14
import sacred.optional as opt
15
from pymongo.errors import AutoReconnect, InvalidDocument, DuplicateKeyError
16
from sacred.commandline_options import CommandLineOption
17
from sacred.dependencies import get_digest
18
from sacred.observers.base import RunObserver
19
from sacred.serializer import flatten
20
from sacred.utils import ObserverError
21
22
23
# Priority given to a MongoObserver when the caller does not specify one.
DEFAULT_MONGO_PRIORITY = 30
24
25
26
def force_valid_bson_key(key):
    """Turn ``key`` into a string intended to be usable as a BSON key.

    The key is converted with ``str``; a leading ``$`` is replaced by ``@``
    and every ``.`` is replaced by ``,``.

    :param key: the original key (any object that ``str`` accepts).
    :return: the sanitized key as a string.
    """
    key = str(key)
    if key.startswith('$'):
        # only the leading '$' is rewritten; later '$' characters are kept
        key = '@' + key[1:]
    return key.replace('.', ',')
32
33
34
def force_bson_encodeable(obj):
    """Return ``obj`` or a stand-in for it that ``bson`` is able to encode.

    * A dict that encodes cleanly (with key checking) is returned unchanged;
      otherwise a new dict is built whose keys went through
      ``force_valid_bson_key`` and whose values were processed recursively.
    * A numpy array (when numpy is available) is returned unchanged.
      NOTE(review): presumably arrays are made storable elsewhere - confirm.
    * Any other object is returned unchanged if it can be encoded as a dict
      value, and replaced by ``str(obj)`` if it cannot.

    :param obj: arbitrary object taken from a run entry.
    :return: ``obj`` itself or an encodable replacement.
    """
    if isinstance(obj, dict):
        try:
            bson.BSON.encode(obj, check_keys=True)
            return obj
        except bson.InvalidDocument:
            return {force_valid_bson_key(k): force_bson_encodeable(v)
                    for k, v in obj.items()}

    elif opt.has_numpy and isinstance(obj, opt.np.ndarray):
        return obj
    else:
        try:
            # wrap in a throw-away dict because only documents can be encoded
            bson.BSON.encode({'dict_just_for_testing': obj})
            return obj
        except bson.InvalidDocument:
            return str(obj)
51
52
53
class MongoObserver(RunObserver):
    """Observer that records the state of a run in a MongoDB collection.

    The run is kept as a single document (``self.run_entry``) inside the
    collection ``self.runs``. Source files, resources and artifacts are
    stored through the GridFS handle ``self.fs``.
    """

    # Collection names that create() refuses to use for storing runs.
    COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties',
                                 'system.indexes', 'search_space'}
    # Written into the 'format' field of every started run.
    VERSION = 'MongoObserver-0.7.0'

    @staticmethod
    def create(url='localhost', db_name='sacred', collection='runs',
               overwrite=None, priority=DEFAULT_MONGO_PRIORITY, **kwargs):
        """Connect to MongoDB and build an observer for one collection.

        :param url: passed to ``pymongo.MongoClient``.
        :param db_name: name of the database to use.
        :param collection: name of the runs collection; must not be one of
            ``COLLECTION_NAME_BLACKLIST``.
        :param overwrite: existing run entry to overwrite, or None.
        :param priority: priority of the created observer.
        :param kwargs: extra keyword arguments for ``pymongo.MongoClient``.
        :raises KeyError: if ``collection`` is a reserved name.
        :return: a new ``MongoObserver``.
        """
        client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        if collection in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError('Collection name "{}" is reserved. '
                           'Please use a different one.'.format(collection))
        runs_collection = database[collection]
        fs = gridfs.GridFS(database)
        return MongoObserver(runs_collection, fs, overwrite=overwrite,
                             priority=priority)

    def __init__(self, runs_collection, fs, overwrite=None,
                 priority=DEFAULT_MONGO_PRIORITY):
        """Store the collection, the GridFS handle and the settings."""
        self.runs = runs_collection
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Insert a new entry with status ``QUEUED`` and return its id.

        :raises RuntimeError: if this observer was set up to overwrite an
            existing entry.
        """
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'QUEUED'
        }
        # set ID if given
        if _id is not None:
            self.run_entry['_id'] = _id
        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create (or overwrite) the entry with status ``RUNNING``.

        :raises RuntimeError: when overwriting and either an entry was
            already started by this observer or the sources of the entry to
            overwrite differ from ``ex_info['sources']``.
        :return: the ``_id`` of the stored entry.
        """
        if self.overwrite is None:
            self.run_entry = {'_id': _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # sanity checks
            if self.overwrite['experiment']['sources'] != ex_info['sources']:
                raise RuntimeError("Sources don't match")
            self.run_entry = self.overwrite

        self.run_entry.update({
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        })

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time):
        """Update info, captured output and heartbeat time, then save."""
        self.run_entry['info'] = flatten(info)
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run ``COMPLETED`` and store it (up to 10 attempts)."""
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = flatten(result)
        self.run_entry['status'] = 'COMPLETED'
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Record the given interrupt status and store it (3 attempts)."""
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run ``FAILED`` with its trace and store it (1 attempt)."""
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Register ``filename`` as a resource of the run.

        The file is uploaded to GridFS unless a file with the same name and
        md5 digest is already stored there.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        with open(filename, 'rb') as f:
            file_id = self.fs.put(f, filename=filename)
        md5hash = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename):
        """Upload ``filename`` to GridFS and record it as artifact ``name``."""
        with open(filename, 'rb') as f:
            run_id = self.run_entry['_id']
            db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id,
                                                       name)
            file_id = self.fs.put(f, filename=db_filename)
        self.run_entry['artifacts'].append({'name': name,
                                            'file_id': file_id})
        self.save()

    def insert(self):
        """Insert ``self.run_entry`` into the runs collection.

        When overwriting, this simply delegates to ``save``. If the entry
        has no ``_id`` (missing or None) one is generated as "largest
        existing ``_id`` + 1" (or 1 for an empty collection); should that id
        be taken in the meantime, a new one is computed and the insert is
        retried.

        :raises ObserverError: if the entry cannot be encoded.
        :raises DuplicateKeyError: if an explicitly given ``_id`` exists.
        """
        if self.overwrite:
            return self.save()

        # queued_event() leaves '_id' out entirely when none was given
        autoinc_key = self.run_entry.get('_id') is None
        while True:
            if autoinc_key:
                c = self.runs.find({}, {'_id': 1})
                c = c.sort('_id', pymongo.DESCENDING).limit(1)
                self.run_entry['_id'] = c.next()['_id'] + 1 if c.count() else 1
            try:
                self.runs.insert_one(self.run_entry)
                return
            except InvalidDocument:
                raise ObserverError('Run contained an unserializable entry.'
                                    '(most likely in the info)')
            except DuplicateKeyError:
                if not autoinc_key:
                    raise
                # the generated id was taken concurrently: loop and retry

    def save(self):
        """Replace the stored document with the current ``self.run_entry``.

        A lost connection (``AutoReconnect``) is ignored on purpose, because
        the next save will write the complete entry again.

        :raises ObserverError: if the entry cannot be encoded.
        """
        try:
            self.runs.replace_one({'_id': self.run_entry['_id']},
                                  self.run_entry)
        except AutoReconnect:
            pass  # just wait for the next save
        except InvalidDocument:
            raise ObserverError('Run contained an unserializable entry.'
                                '(most likely in the info)')

    def final_save(self, attempts):
        """Try up to ``attempts`` times to store the final run entry.

        After an ``AutoReconnect`` the method waits one second before the
        next attempt. After an ``InvalidDocument`` the entry is passed
        through ``force_bson_encodeable`` and a warning is printed. If no
        attempt succeeds the entry is pickled into a temporary file whose
        name is printed to stderr.

        NOTE(review): an ``InvalidDocument`` uses up one attempt, so with
        ``attempts=1`` the sanitized entry goes straight to the pickle
        fallback - confirm this is intended.
        """
        for i in range(attempts):
            try:
                self.runs.save(self.run_entry)
                return
            except AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except InvalidDocument:
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment! "
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as f:
            pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

    def save_sources(self, ex_info):
        """Make sure all source files are in GridFS and return their ids.

        :param ex_info: mapping with ``'base_dir'`` and ``'sources'``, the
            latter being an iterable of ``(source_name, md5)`` pairs.
        :return: list of ``[source_name, gridfs_id]`` entries.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        for source_name, md5 in ex_info['sources']:
            abs_path = os.path.join(base_dir, source_name)
            stored = self.fs.find_one({'filename': abs_path, 'md5': md5})
            if stored:
                _id = stored._id
            else:
                with open(abs_path, 'rb') as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        """Observers are equal when they write to equal collections."""
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
254
255
256
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.collection][!priority]"

    # The '.' allowed here lets "db_name.collection" match as one token;
    # it is split on the first '.' in parse_mongo_db_arg().
    DB_NAME_PATTERN = r"[_A-Za-z][0-9A-Za-z#%&'()+\-;=@\[\]^_{}.]{0,63}"
    HOSTNAME_PATTERN = \
        r"(?=.{1,255}$)"\
        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"\
        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*"\
        r"\.?"
    URL_PATTERN = "(?:" + HOSTNAME_PATTERN + ")" + ":" + "(?:[0-9]{1,5})"
    # Optional "!<int>" suffix; raw string so that \d is a regex escape.
    PRIORITY_PATTERN = r"(?P<priority>!-?\d+)?"
    # Named groups keep the "!priority" suffix out of the extracted values.
    DB_NAME = re.compile("^(?P<db_name>" + DB_NAME_PATTERN + ")" +
                         PRIORITY_PATTERN + "$")
    URL = re.compile("^(?P<url>" + URL_PATTERN + ")" +
                     PRIORITY_PATTERN + "$")
    URL_DB_NAME = re.compile("^(?P<url>" + URL_PATTERN + ")" + ":" +
                             "(?P<db_name>" + DB_NAME_PATTERN + ")" +
                             PRIORITY_PATTERN + "$")

    @classmethod
    def apply(cls, args, run):
        # Parse the option value and attach a matching MongoObserver to the
        # run. An empty collection name means "use create()'s default".
        url, db_name, collection, priority = cls.parse_mongo_db_arg(args)
        create_kwargs = {'db_name': db_name, 'url': url, 'priority': priority}
        if collection:
            create_kwargs['collection'] = collection
        mongo = MongoObserver.create(**create_kwargs)

        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        # Split a "[host:port:]db_name[.collection][!priority]" string into
        # the tuple (url, db_name, collection, priority).
        #
        # Missing parts default to 'localhost:27017', 'sacred', '' and
        # DEFAULT_MONGO_PRIORITY respectively. Raises ValueError if the
        # string matches none of the supported forms.
        def get_priority(match):
            prio_str = match.group('priority')
            if prio_str is None:
                return DEFAULT_MONGO_PRIORITY
            return int(prio_str[1:])  # drop the leading '!'

        match = cls.DB_NAME.match(mongo_db)
        if match:
            db_name, _, collection = match.group('db_name').partition('.')
            return ('localhost:27017', db_name, collection,
                    get_priority(match))

        match = cls.URL.match(mongo_db)
        if match:
            return match.group('url'), 'sacred', '', get_priority(match)

        match = cls.URL_DB_NAME.match(mongo_db)
        if match:
            db_name, _, collection = match.group('db_name').partition('.')
            return (match.group('url'), db_name, collection,
                    get_priority(match))

        raise ValueError('mongo_db argument must have the form "db_name" '
                         'or "host:port[:db_name]" but was {}'
                         .format(mongo_db))
314