Completed
Push — master ( c73972...b998ba )
by Klaus
32s
created

MongoDbOption   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 60
Duplicated Lines 0 %

Importance

Changes 1
Bugs 0 Features 1
Metric Value
c 1
b 0
f 1
dl 0
loc 60
rs 10
wmc 7

2 Methods

Rating   Name   Duplication   Size   Complexity  
B parse_mongo_db_arg() 0 13 5
A apply() 0 13 2
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import pickle
6
import re
7
import os.path
8
import sys
9
import time
10
11
import sacred.optional as opt
12
from sacred.commandline_options import CommandLineOption
13
from sacred.dependencies import get_digest
14
from sacred.observers.base import RunObserver
15
from sacred.serializer import flatten
16
from sacred.utils import ObserverError
17
18
19
# Default value for the observer-priority parameter used by
# MongoObserver.create() and MongoDbOption.parse_mongo_db_arg().
DEFAULT_MONGO_PRIORITY = 30
20
21
22
def force_valid_bson_key(key):
    """Return *key* coerced into a string that BSON accepts as a document key.

    MongoDB rejects keys that start with ``'$'`` or contain ``'.'``, so a
    leading ``'$'`` is turned into ``'@'`` and every ``'.'`` becomes ``','``.
    """
    safe = str(key)
    if safe[:1] == '$':
        safe = '@' + safe[1:]
    return safe.replace('.', ',')
28
29
30
def force_bson_encodeable(obj):
    """Coerce *obj* into something the BSON encoder can store.

    A dict that fails to encode is rebuilt with sanitized keys and
    recursively coerced values.  Numpy arrays are passed through
    unchanged (presumably handled downstream -- NOTE(review): confirm).
    Any other value that BSON rejects is converted to its string form.
    """
    import bson

    if isinstance(obj, dict):
        try:
            # check_keys also catches '$'-prefixed and dotted keys
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            return {force_valid_bson_key(key): force_bson_encodeable(value)
                    for key, value in obj.items()}
        return obj

    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        return obj

    try:
        # wrap in a throwaway dict because BSON can only encode documents
        bson.BSON.encode({'dict_just_for_testing': obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
48
49
50
class MongoObserver(RunObserver):
    """Run observer that stores experiment runs in a MongoDB database."""

    # Collection names that may not be used for storing runs because they
    # are reserved (GridFS buckets, internal metadata).
    # NOTE(review): 'seach_space' looks like a typo of 'search_space';
    # kept as-is here to preserve behavior -- confirm against upstream.
    COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties',
                                 'system.indexes', 'seach_space'}
    # Format tag written into every run entry (see started_event).
    VERSION = 'MongoObserver-0.7.0'
54
55
    @staticmethod
    def create(url='localhost', db_name='sacred', collection='runs',
               overwrite=None, priority=DEFAULT_MONGO_PRIORITY, **kwargs):
        """Factory: connect to MongoDB and build a configured MongoObserver.

        Extra keyword arguments are forwarded to ``pymongo.MongoClient``.
        Raises ``KeyError`` if *collection* is one of the reserved names.
        """
        import pymongo
        import gridfs

        client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        if collection in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError('Collection name "{}" is reserved. '
                           'Please use a different one.'.format(collection))
        return MongoObserver(database[collection],
                             gridfs.GridFS(database),
                             overwrite=overwrite,
                             metrics_collection=database["metrics"],
                             priority=priority)
72
73
    def __init__(self, runs_collection,
                 fs, overwrite=None, metrics_collection=None,
                 priority=DEFAULT_MONGO_PRIORITY):
        """Initialize the observer.

        :param runs_collection: pymongo collection that stores run entries.
        :param fs: GridFS instance used for sources/artifacts/resources.
        :param overwrite: ``None``, an existing run ``_id`` (int/str) to
            overwrite, or an already-fetched run document.
        :param metrics_collection: optional collection for logged metrics.
        :param priority: observer priority.
        :raises RuntimeError: if an ``_id`` was given but no such run exists.
        """
        self.runs = runs_collection
        self.metrics = metrics_collection
        self.fs = fs
        if isinstance(overwrite, (int, str)):
            # Bug fix: this branch used to read `self.overwrite`, which is
            # only assigned AFTER the branch, so passing an id always raised
            # AttributeError. Use the local `overwrite` argument instead.
            run = self.runs.find_one({'_id': overwrite})
            if run is None:
                raise RuntimeError("Couldn't find run to overwrite with "
                                   "_id='{}'".format(overwrite))
            else:
                # replace the id with the full run document
                overwrite = run
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority
88
89
    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Record a run in status QUEUED and return its ``_id``."""
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        entry = {
            'experiment': dict(ex_info),
            'command': command,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'QUEUED',
        }
        if _id is not None:
            # honour an externally chosen id
            entry['_id'] = _id
        self.run_entry = entry
        # upload source files and reference them from the entry
        entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return entry['_id']
107
108
    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Record the start of a run (status RUNNING) and return its ``_id``."""
        if self.overwrite is None:
            self.run_entry = {'_id': _id}
        elif self.run_entry is not None:
            raise RuntimeError("Cannot overwrite more than once!")
        elif self.overwrite['experiment']['sources'] != ex_info['sources']:
            # sanity check: only overwrite a run of the same code base
            raise RuntimeError("Sources don't match")
        else:
            self.run_entry = self.overwrite

        self.run_entry.update(
            experiment=dict(ex_info),
            format=self.VERSION,
            command=command,
            host=dict(host_info),
            start_time=start_time,
            config=flatten(config),
            meta=meta_info,
            status='RUNNING',
            resources=[],
            artifacts=[],
            captured_out='',
            info={},
            heartbeat=None,
        )

        # upload source files and reference them from the entry
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']
140
141
    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Persist a periodic snapshot of the running experiment."""
        entry = self.run_entry
        entry['info'] = flatten(info)
        entry['captured_out'] = captured_out
        entry['heartbeat'] = beat_time
        entry['result'] = flatten(result)
        self.save()
147
148
    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and persist it (up to 10 save attempts)."""
        self.run_entry.update(stop_time=stop_time,
                              result=flatten(result),
                              status='COMPLETED')
        self.final_save(attempts=10)
153
154
    def interrupted_event(self, interrupt_time, status):
        """Record an interruption with the given status (3 save attempts)."""
        self.run_entry.update(stop_time=interrupt_time,
                              status=status)
        self.final_save(attempts=3)
158
159
    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED, store the traceback, save once."""
        self.run_entry.update(stop_time=fail_time,
                              status='FAILED',
                              fail_trace=fail_trace)
        self.final_save(attempts=1)
164
165
    def resource_event(self, filename):
        """Register *filename* as a resource, uploading it to GridFS if new.

        A file that is already stored with an identical md5 digest is only
        referenced; otherwise (new file, or same name with changed content)
        it is uploaded again.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                # identical copy already stored -- just reference it
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        # not stored yet (or contents changed): upload and record
        with open(filename, 'rb') as handle:
            file_id = self.fs.put(handle, filename=filename)
        digest = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, digest))
        self.save()
179
180
    def artifact_event(self, name, filename):
        """Upload *filename* to GridFS and attach it to the run as *name*."""
        run_id = self.run_entry['_id']
        # namespaced GridFS filename: artifact://<collection>/<run>/<name>
        db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id,
                                                   name)
        with open(filename, 'rb') as handle:
            file_id = self.fs.put(handle, filename=db_filename)
        self.run_entry['artifacts'].append({'name': name,
                                            'file_id': file_id})
        self.save()
189
190
    def log_metrics(self, metrics_by_name, info):
        """Store new measurements to the database.

        Append the given steps/values/timestamps to the per-metric
        documents in the metrics collection (creating them on first use)
        and register every newly created metric under info["metrics"].
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been
            # set, do not try to save anything there.
            return
        for name, measurements in metrics_by_name.items():
            selector = {"run_id": self.run_entry['_id'],
                        "name": name}
            appends = {
                field: {"$each": measurements[field]}
                for field in ("steps", "values", "timestamps")
            }
            result = self.metrics.update_one(selector, {"$push": appends},
                                             upsert=True)
            if result.upserted_id is not None:
                # first datapoint for this metric: remember its document id
                info.setdefault("metrics", []) \
                    .append({"name": name, "id": str(result.upserted_id)})
215
216
    def insert(self):
        """Insert ``self.run_entry`` as a new document.

        When no ``_id`` is present, emulates an auto-incrementing key
        (max existing ``_id`` + 1) and retries on id collisions.

        :raises ObserverError: if the entry is not BSON-serializable.
        :raises pymongo.errors.DuplicateKeyError: for an explicit
            (non-auto-incremented) duplicate ``_id``.
        """
        import pymongo.errors

        if self.overwrite:
            # Overwriting an existing run: update it in place instead.
            return self.save()

        autoinc_key = self.run_entry.get('_id') is None
        while True:
            if autoinc_key:
                # next id = largest existing _id + 1, or 1 for an empty
                # collection (NOTE: cursor.next()/count() are legacy pymongo
                # APIs -- kept for compatibility with the pinned version).
                c = self.runs.find({}, {'_id': 1})
                c = c.sort('_id', pymongo.DESCENDING).limit(1)
                self.run_entry['_id'] = c.next()['_id'] + 1 if c.count() else 1
            try:
                self.runs.insert_one(self.run_entry)
                # Bug fix: return only on a SUCCESSFUL insert. The return
                # used to sit after the except clauses, so a DuplicateKeyError
                # on an auto-incremented id fell through and silently dropped
                # the run instead of retrying with a fresh id.
                return
            except pymongo.errors.InvalidDocument as e:
                raise ObserverError('Run contained an unserializable entry.'
                                    '(most likely in the info)\n{}'.format(e))
            except pymongo.errors.DuplicateKeyError:
                if not autoinc_key:
                    raise
                # else: lost a race on the auto-increment id; loop and retry
237
238
    def save(self):
        """Replace the stored run entry with the current one (best effort).

        AutoReconnect errors are swallowed; the next save will retry.

        :raises ObserverError: if the entry is not BSON-serializable.
        """
        import pymongo.errors

        selector = {'_id': self.run_entry['_id']}
        try:
            self.runs.replace_one(selector, self.run_entry)
        except pymongo.errors.AutoReconnect:
            pass  # just wait for the next save
        except pymongo.errors.InvalidDocument:
            raise ObserverError('Run contained an unserializable entry.'
                                '(most likely in the info)')
249
250
    def final_save(self, attempts):
        """Try hard to persist the final run entry.

        Retries up to *attempts* times on AutoReconnect (sleeping 1s between
        tries) and coerces the entry with force_bson_encodeable() on
        InvalidDocument.  If every attempt fails, the entry is pickled to a
        temporary file so no data is lost.  (NOTE: Collection.save() is a
        legacy pymongo API -- kept for compatibility with the pinned version.)
        """
        import pymongo.errors

        for attempt in range(attempts):
            try:
                self.runs.save(self.run_entry)
                return
            except pymongo.errors.AutoReconnect:
                if attempt + 1 < attempts:
                    time.sleep(1)
            except pymongo.errors.InvalidDocument:
                # sanitize the entry and warn; the next attempt should work
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment!"
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        # all attempts failed: dump the entry to disk as a last resort
        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as dump_file:
            pickle.dump(self.run_entry, dump_file)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(dump_file.name),
                  file=sys.stderr)
275
276
    def save_sources(self, ex_info):
        """Upload the experiment's source files to GridFS if not yet stored.

        Returns a list of ``[source_name, file_id]`` pairs referencing the
        stored files.
        """
        base_dir = ex_info['base_dir']
        stored = []
        for source_name, md5 in ex_info['sources']:
            abs_path = os.path.join(base_dir, source_name)
            existing = self.fs.find_one({'filename': abs_path, 'md5': md5})
            if existing:
                # identical file already uploaded -- reuse it
                file_id = existing._id
            else:
                with open(abs_path, 'rb') as handle:
                    file_id = self.fs.put(handle, filename=abs_path)
            stored.append([source_name, file_id])
        return stored
289
290
    def __eq__(self, other):
        """Observers are equal iff they write to the same runs collection."""
        if not isinstance(other, MongoObserver):
            return False
        return self.runs == other.runs
294
295
    def __ne__(self, other):
        """Inverse of __eq__ (explicit for Python 2 compatibility)."""
        is_equal = self.__eq__(other)
        return not is_equal
297
298
299
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    __depends_on__ = 'pymongo'

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.collection[:id]][!priority]"

    # Regex building blocks for the DB argument format described above.
    RUN_ID_PATTERN = r"(?P<run_id>\d{1,12})"
    PORT1_PATTERN = r"(?P<port1>\d{1,5})"
    PORT2_PATTERN = r"(?P<port2>\d{1,5})"
    # raw string (was the only non-raw pattern; '\d' triggers an
    # invalid-escape-sequence warning on newer Pythons)
    PRIORITY_PATTERN = r"(?P<priority>-?\d+)?"
    DB_NAME_PATTERN = r"(?P<db>[_A-Za-z][0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
    COLL_NAME_PATTERN = r"(?P<coll>[_A-Za-z][0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
    HOSTNAME1_PATTERN = r"(?P<host1>" \
                        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
                        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*)"
    HOSTNAME2_PATTERN = r"(?P<host2>" \
                        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
                        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*)"

    # "host:port" only (database defaults to 'sacred')
    HOST_ONLY = r"^(?:{host}:{port})$".format(host=HOSTNAME1_PATTERN, port=PORT1_PATTERN)
    # full form: [host:port:]db_name[.collection[:run_id]][!priority]
    FULL = r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?(?:!{priority})?$".format(
        host=HOSTNAME2_PATTERN,
        port=PORT2_PATTERN,
        db=DB_NAME_PATTERN,
        collection=COLL_NAME_PATTERN,
        rid=RUN_ID_PATTERN,
        priority=PRIORITY_PATTERN)

    PATTERN = r"{host_only}|{full}".format(host_only=HOST_ONLY, full=FULL)

    @classmethod
    def apply(cls, args, run):
        """Parse the DB argument and attach a matching MongoObserver to *run*."""
        url, db_name, collection, run_id, priority = cls.parse_mongo_db_arg(args)
        if collection:
            # an explicit collection may also carry a run id to overwrite
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         collection=collection,
                                         priority=priority,
                                         overwrite=run_id)
        else:
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         priority=priority)

        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        """Split *mongo_db* into ``(url, db_name, collection, run_id, priority)``.

        :raises ValueError: if the argument does not match PATTERN.
        """
        match = re.match(cls.PATTERN, mongo_db)
        # Bug fix: .groupdict() used to be called BEFORE this None check, so
        # a malformed argument raised AttributeError ('NoneType' has no
        # attribute 'groupdict') instead of the intended ValueError.
        if match is None:
            raise ValueError('mongo_db argument must have the form "db_name" '
                             'or "host:port[:db_name]" but was {}'
                             .format(mongo_db))
        g = match.groupdict()
        if g['host1']:
            # host:port only -> default database 'sacred', default priority
            return ('{}:{}'.format(g['host1'], g['port1']),
                    'sacred', None, None, DEFAULT_MONGO_PRIORITY)
        host, port = ((g['host2'], g['port2']) if g['host2']
                      else ('localhost', 27017))
        priority = (int(g['priority']) if g['priority'] is not None
                    else DEFAULT_MONGO_PRIORITY)
        return ('{}:{}'.format(host, port),
                g['db'], g['coll'], g['run_id'], priority)