MongoObserver.__eq__()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
c 0
b 0
f 0
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import pickle
6
import re
7
import os.path
8
import sys
9
import time
10
11
import sacred.optional as opt
12
from sacred.commandline_options import CommandLineOption
13
from sacred.dependencies import get_digest
14
from sacred.observers.base import RunObserver
15
from sacred.serializer import flatten
16
from sacred.utils import ObserverError
17
18
19
# Default value for the ``priority`` argument of MongoObserver.
DEFAULT_MONGO_PRIORITY = 30
20
21
22
def force_valid_bson_key(key):
    """Coerce *key* into a string that is a legal BSON document key.

    A leading ``$`` becomes ``@`` and every ``.`` becomes ``,``, since
    MongoDB forbids both characters in field names.
    """
    key = str(key)
    if key[:1] == '$':
        key = '@' + key[1:]
    return key.replace('.', ',')
28
29
30
def force_bson_encodeable(obj):
    """Return *obj* transformed so that it can be BSON-encoded.

    Dicts whose keys bson rejects are rebuilt recursively with
    ``force_valid_bson_key``; numpy arrays are returned unchanged; any
    other value that bson refuses to encode is replaced by ``str(obj)``.
    """
    import bson
    if isinstance(obj, dict):
        try:
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            # Sanitise keys and recurse into the values.
            return {force_valid_bson_key(key): force_bson_encodeable(value)
                    for key, value in obj.items()}
        return obj
    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        # ndarrays are handled elsewhere (by the serializer); pass through.
        return obj
    try:
        # Probe encodability by wrapping the value in a throwaway document.
        bson.BSON.encode({'dict_just_for_testing': obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
48
49
50
class MongoObserver(RunObserver):
    """Observer that stores all run information in a MongoDB database.

    Run documents are kept in ``runs_collection``; files (sources,
    artifacts, resources) are stored via GridFS; scalar metric series go
    into a separate ``metrics`` collection.
    """

    # Collection names that may not be used for runs because they are
    # reserved by GridFS / MongoDB internals or by sacred itself.
    # NOTE(review): 'seach_space' looks like a typo for 'search_space' --
    # kept as-is because fixing it would change which names are rejected.
    COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties',
                                 'system.indexes', 'seach_space'}
    VERSION = 'MongoObserver-0.7.0'

    @staticmethod
    def create(url=None, db_name='sacred', collection='runs',
               overwrite=None, priority=DEFAULT_MONGO_PRIORITY,
               client=None, **kwargs):
        """Connect to MongoDB and build a fully wired MongoObserver.

        Pass either an existing ``client`` or a connection ``url`` (plus
        any extra ``pymongo.MongoClient`` keyword arguments), not both.
        Raises KeyError if ``collection`` is a reserved name.
        """
        import pymongo
        import gridfs

        if client is not None:
            assert isinstance(client, pymongo.MongoClient)
            assert url is None, 'Cannot pass both a client and an url.'
        else:
            client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        if collection in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError('Collection name "{}" is reserved. '
                           'Please use a different one.'.format(collection))
        runs_collection = database[collection]
        metrics_collection = database["metrics"]
        fs = gridfs.GridFS(database)
        return MongoObserver(runs_collection,
                             fs, overwrite=overwrite,
                             metrics_collection=metrics_collection,
                             priority=priority)

    def __init__(self, runs_collection,
                 fs, overwrite=None, metrics_collection=None,
                 priority=DEFAULT_MONGO_PRIORITY):
        self.runs = runs_collection        # collection holding run documents
        self.metrics = metrics_collection  # None => metric logging disabled
        self.fs = fs                       # GridFS for stored files
        if isinstance(overwrite, (int, str)):
            # An id was given: resolve it to the existing run document.
            overwrite = int(overwrite)
            run = self.runs.find_one({'_id': overwrite})
            if run is None:
                raise RuntimeError("Couldn't find run to overwrite with "
                                   "_id='{}'".format(overwrite))
            else:
                overwrite = run
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        """Create a QUEUED run document and return its ``_id``."""
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'host': dict(host_info),
            'config': flatten(config),
            'meta': meta_info,
            'status': 'QUEUED'
        }
        # set ID if given
        if _id is not None:
            self.run_entry['_id'] = _id
        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create (or overwrite) the RUNNING run document; return its id."""
        if self.overwrite is None:
            self.run_entry = {'_id': _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # TODO sanity checks
            self.run_entry = self.overwrite

        self.run_entry.update({
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        })

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Refresh info, captured output, heartbeat time and result."""
        self.run_entry['info'] = flatten(info)
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.run_entry['result'] = flatten(result)
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and persist it (with retries)."""
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = flatten(result)
        self.run_entry['status'] = 'COMPLETED'
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Record an interruption with the given status string."""
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED and store the traceback."""
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Register *filename* as a resource, uploading it if necessary.

        If a file with the same name and md5 already lives in GridFS it
        is reused instead of being uploaded again.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        with open(filename, 'rb') as f:
            file_id = self.fs.put(f, filename=filename)
        md5hash = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename):
        """Upload an artifact file to GridFS and record it on the run."""
        with open(filename, 'rb') as f:
            run_id = self.run_entry['_id']
            db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id,
                                                       name)
            file_id = self.fs.put(f, filename=db_filename)
        self.run_entry['artifacts'].append({'name': name,
                                            'file_id': file_id})
        self.save()

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save there anything
            return
        for key in metrics_by_name:
            query = {"run_id": self.run_entry['_id'],
                     "name": key}
            push = {"steps": {"$each": metrics_by_name[key]["steps"]},
                    "values": {"$each": metrics_by_name[key]["values"]},
                    "timestamps": {"$each": metrics_by_name[key]["timestamps"]}
                    }
            update = {"$push": push}
            result = self.metrics.update_one(query, update, upsert=True)
            if result.upserted_id is not None:
                # This is the first time we are storing this metric
                info.setdefault("metrics", []) \
                    .append({"name": key, "id": str(result.upserted_id)})

    def insert(self):
        """Insert ``self.run_entry``, auto-incrementing ``_id`` if unset.

        Raises ObserverError if the entry is not BSON-serializable.
        """
        import pymongo.errors

        if self.overwrite:
            return self.save()

        autoinc_key = self.run_entry.get('_id') is None
        while True:
            if autoinc_key:
                # Next id = (current max _id) + 1, or 1 if the collection
                # is empty. next(cursor, None) replaces the deprecated
                # cursor.count()/cursor.next() pair of the old code.
                cursor = self.runs.find({}, {'_id': 1}) \
                                  .sort('_id', pymongo.DESCENDING).limit(1)
                last = next(cursor, None)
                self.run_entry['_id'] = \
                    last['_id'] + 1 if last is not None else 1
            try:
                self.runs.insert_one(self.run_entry)
                return
            except pymongo.errors.InvalidDocument as e:
                raise ObserverError('Run contained an unserializable entry.'
                                    '(most likely in the info)\n{}'.format(e))
            except pymongo.errors.DuplicateKeyError:
                # Another writer claimed the auto-incremented id first;
                # retry with a fresh one. Explicit ids must not be retried.
                if not autoinc_key:
                    raise

    def save(self):
        """Best-effort save of the current run entry (swallows reconnects)."""
        import pymongo.errors

        try:
            self.runs.replace_one({'_id': self.run_entry['_id']},
                                  self.run_entry)
        except pymongo.errors.AutoReconnect:
            pass  # just wait for the next save
        except pymongo.errors.InvalidDocument:
            raise ObserverError('Run contained an unserializable entry.'
                                '(most likely in the info)')

    def final_save(self, attempts):
        """Persist the run entry, retrying up to *attempts* times.

        AutoReconnect errors are retried after a one-second pause;
        unserializable documents are coerced via force_bson_encodeable.
        As a last resort the entry is pickled to a temp file so the data
        is not lost.
        """
        import pymongo.errors

        for i in range(attempts):
            try:
                # replace_one(..., upsert=True) supersedes the deprecated
                # Collection.save() with the same upsert-by-_id semantics.
                self.runs.replace_one({'_id': self.run_entry['_id']},
                                      self.run_entry, upsert=True)
                return
            except pymongo.errors.AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except pymongo.errors.InvalidDocument:
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment!"
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as f:
            pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

    def save_sources(self, ex_info):
        """Upload source files to GridFS (deduplicated by name+md5).

        Returns a list of ``[source_name, file_id]`` pairs.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        for source_name, md5 in ex_info['sources']:
            abs_path = os.path.join(base_dir, source_name)
            file = self.fs.find_one({'filename': abs_path, 'md5': md5})
            if file:
                _id = file._id
            else:
                with open(abs_path, 'rb') as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        # Two observers are equal iff they write to the same runs collection.
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
304
305
306
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    __depends_on__ = 'pymongo'

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.collection[:id]][!priority]"

    RUN_ID_PATTERN = r"(?P<overwrite>\d{1,12})"
    PORT1_PATTERN = r"(?P<port1>\d{1,5})"
    PORT2_PATTERN = r"(?P<port2>\d{1,5})"
    # raw string avoids an invalid-escape warning for '\d'
    PRIORITY_PATTERN = r"(?P<priority>-?\d+)?"
    DB_NAME_PATTERN = r"(?P<db_name>[_A-Za-z]" \
                      r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
    COLL_NAME_PATTERN = r"(?P<collection>[_A-Za-z]" \
                        r"[0-9A-Za-z#%&'()+\-;=@\[\]^_{}]{0,63})"
    HOSTNAME1_PATTERN = r"(?P<host1>" \
                        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
                        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \
                        r"[0-9A-Za-z])?)*)"
    HOSTNAME2_PATTERN = r"(?P<host2>" \
                        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
                        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}" \
                        r"[0-9A-Za-z])?)*)"

    HOST_ONLY = r"^(?:{host}:{port})$".format(host=HOSTNAME1_PATTERN,
                                              port=PORT1_PATTERN)
    FULL = r"^(?:{host}:{port}:)?{db}(?:\.{collection}(?::{rid})?)?" \
           r"(?:!{priority})?$".format(host=HOSTNAME2_PATTERN,
                                       port=PORT2_PATTERN,
                                       db=DB_NAME_PATTERN,
                                       collection=COLL_NAME_PATTERN,
                                       rid=RUN_ID_PATTERN,
                                       priority=PRIORITY_PATTERN)

    PATTERN = r"{host_only}|{full}".format(host_only=HOST_ONLY, full=FULL)

    @classmethod
    def apply(cls, args, run):
        """Parse the DB argument and attach a MongoObserver to *run*."""
        kwargs = cls.parse_mongo_db_arg(args)
        mongo = MongoObserver.create(**kwargs)
        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        """Turn a DB specification string into MongoObserver.create kwargs.

        Raises ValueError if *mongo_db* does not match the
        [host:port:]db_name[.collection[:id]][!priority] format.
        """
        match = re.match(cls.PATTERN, mongo_db)
        # BUG FIX: re.match returns None on failure, so groupdict() must
        # only be called on a successful match -- the old code called it
        # unconditionally, raising AttributeError and making the intended
        # ValueError branch unreachable.
        if match is None:
            raise ValueError('mongo_db argument must have the form "db_name" '
                             'or "host:port[:db_name]" but was {}'
                             .format(mongo_db))
        g = match.groupdict()

        kwargs = {}
        if g['host1']:
            kwargs['url'] = '{}:{}'.format(g['host1'], g['port1'])
        elif g['host2']:
            kwargs['url'] = '{}:{}'.format(g['host2'], g['port2'])

        if g['priority'] is not None:
            kwargs['priority'] = int(g['priority'])

        for p in ['db_name', 'collection', 'overwrite']:
            if g[p] is not None:
                kwargs[p] = g[p]

        return kwargs
372