Completed
Pull Request — master (#184)
by Martin
44s
created

MongoObserver   B

Complexity

Total Complexity 49

Size/Duplication

Total Lines 233
Duplicated Lines 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric Value
c 4
b 0
f 0
dl 0
loc 233
rs 8.5454
wmc 49

17 Methods

Rating   Name   Duplication   Size   Complexity  
A heartbeat_event() 0 6 1
A completed_event() 0 5 1
B final_save() 0 23 6
A __eq__() 0 4 2
B started_event() 0 32 4
A save() 0 8 3
A create() 0 15 2
A failed_event() 0 5 1
A __ne__() 0 2 1
A queued_event() 0 18 3
A interrupted_event() 0 4 1
B resource_event() 0 14 5
B log_metrics() 0 25 4
A save_sources() 0 13 4
A artifact_event() 0 9 2
C insert() 0 19 8
A __init__() 0 9 1

How to fix   Complexity   

Complex Class

Complex classes like MongoObserver often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import pickle
6
import re
7
import os.path
8
import sys
9
import time
10
11
import bson
12
import gridfs
13
import pymongo
14
import sacred.optional as opt
15
from pymongo.errors import AutoReconnect, InvalidDocument, DuplicateKeyError
16
from sacred.commandline_options import CommandLineOption
17
from sacred.dependencies import get_digest
18
from sacred.observers.base import RunObserver
19
from sacred.serializer import flatten
20
from sacred.utils import ObserverError
21
22
23
# Priority assigned to a MongoObserver when the caller (or the command-line
# DB specification) does not provide an explicit one.
DEFAULT_MONGO_PRIORITY = 30
24
25
26
def force_valid_bson_key(key):
    """Return ``key`` converted to a string without a leading '$' or any '.'.

    The key is passed through ``str``; a leading '$' is replaced by '@' and
    every '.' is replaced by ','. Only the first character is checked for
    '$'; later '$' characters are kept.
    """
    text = str(key)
    if text[:1] == '$':
        text = '@' + text[1:]
    return text.replace('.', ',')
32
33
34
def force_bson_encodeable(obj):
    """Return ``obj`` or a stand-in for it that BSON is able to encode.

    * A dict that already encodes (with key checking) is returned unchanged;
      otherwise a new dict is built whose keys went through
      ``force_valid_bson_key`` and whose values went through this function.
    * A numpy array (when numpy is available) is returned unchanged.
    * Any other object is returned unchanged if it encodes as a document
      value, and is replaced by ``str(obj)`` if it does not.
    """
    if isinstance(obj, dict):
        try:
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            # Something in the mapping is not encodable: rebuild it entry by
            # entry with sanitised keys and values.
            return {force_valid_bson_key(name): force_bson_encodeable(value)
                    for name, value in obj.items()}
        return obj

    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        # Arrays are passed through without an encoding probe.
        return obj

    try:
        # Wrap the value in a throw-away document to probe encodability.
        bson.BSON.encode({'dict_just_for_testing': obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
51
52
53
class MongoObserver(RunObserver):
    """Run observer that records a run as a document in a MongoDB collection.

    The document for the current run is kept in ``self.run_entry`` and
    written to the ``self.runs`` collection. Source files, resources and
    artifacts are stored through the GridFS handle ``self.fs``; metric
    series are pushed to the ``self.metrics`` collection when one is set.
    """

    # Names that create() refuses to use for the runs collection.
    # NOTE(review): 'seach_space' looks like a typo for 'search_space' --
    # confirm before changing it, since that alters which names are rejected.
    COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties',
                                 'system.indexes', 'seach_space'}
    # Written to the 'format' field of every started run.
    VERSION = 'MongoObserver-0.7.0'

    @staticmethod
    def create(url='localhost', db_name='sacred', collection='runs',
               overwrite=None, priority=DEFAULT_MONGO_PRIORITY, **kwargs):
        """Connect to MongoDB and build a MongoObserver from the connection.

        :param url: passed to ``pymongo.MongoClient``.
        :param db_name: name of the database to use.
        :param collection: name of the runs collection inside that database.
        :param overwrite: forwarded to the constructor.
        :param priority: forwarded to the constructor.
        :param kwargs: extra keyword arguments for ``pymongo.MongoClient``.
        :raises KeyError: if ``collection`` is a reserved collection name.
        :return: the new MongoObserver; metrics go to the "metrics"
                 collection of the same database.
        """
        # Validate first so that a rejected name does not leave an unused
        # client behind.
        if collection in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError('Collection name "{}" is reserved. '
                           'Please use a different one.'.format(collection))
        client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        runs_collection = database[collection]
        metrics_collection = database["metrics"]
        fs = gridfs.GridFS(database)
        return MongoObserver(runs_collection,
                             fs, overwrite=overwrite,
                             metrics_collection=metrics_collection,
                             priority=priority)

    def __init__(self, runs_collection,
                 fs, overwrite=None, metrics_collection=None,
                 priority=DEFAULT_MONGO_PRIORITY):
        """Store the collections and GridFS handle; no run entry exists yet.

        :param runs_collection: collection that receives the run documents.
        :param fs: GridFS handle used for sources, resources and artifacts.
        :param overwrite: existing run entry (a dict) to be overwritten by
                          the next started run, or None.
        :param metrics_collection: collection for metric series, or None to
                                   disable metric logging.
        :param priority: stored as ``self.priority``.
        """
        self.runs = runs_collection
        self.metrics = metrics_collection
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = None
        self.priority = priority

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Insert a run entry with status 'QUEUED' and return its ``_id``.

        ``queue_time`` is accepted but not stored in the entry.

        :raises RuntimeError: if this observer was set up to overwrite a run.
        """
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'QUEUED'
        }
        # set ID if given
        if _id is not None:
            self.run_entry['_id'] = _id
        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create (or overwrite) the run entry with status 'RUNNING'.

        Without ``overwrite`` a fresh entry with the given ``_id`` is
        started. With ``overwrite`` that entry is reused, in which case the
        ``_id`` argument is not consulted.

        :raises RuntimeError: when overwriting and a run entry already
                              exists, or when the sources of the overwritten
                              entry differ from ``ex_info['sources']``.
        :return: the ``_id`` of the stored run entry.
        """
        if self.overwrite is None:
            self.run_entry = {'_id': _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # sanity checks
            if self.overwrite['experiment']['sources'] != ex_info['sources']:
                raise RuntimeError("Sources don't match")
            self.run_entry = self.overwrite

        self.run_entry.update({
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        })

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Update info, captured output, heartbeat time and result; save."""
        self.run_entry['info'] = flatten(info)
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        # Flattened exactly like in completed_event so that intermediate and
        # final results are stored in the same form.
        self.run_entry['result'] = flatten(result)
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run 'COMPLETED' and store it (up to 10 attempts)."""
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = flatten(result)
        self.run_entry['status'] = 'COMPLETED'
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Store the given ``status`` and stop time (up to 3 attempts)."""
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run 'FAILED', keep the trace, and store it (1 attempt)."""
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Record ``filename`` as a resource, uploading it when necessary.

        If GridFS already holds a file with this name and the same md5
        digest, only the (filename, md5) reference is added to the run entry
        (once). Otherwise the file is uploaded and then referenced.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        with open(filename, 'rb') as f:
            file_id = self.fs.put(f, filename=filename)
        md5hash = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename):
        """Upload ``filename`` to GridFS and reference it as artifact ``name``.

        The stored file is named
        ``artifact://<runs collection name>/<run id>/<name>``.
        """
        with open(filename, 'rb') as f:
            run_id = self.run_entry['_id']
            db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id,
                                                       name)
            file_id = self.fs.put(f, filename=db_filename)
        self.run_entry['artifacts'].append({'name': name,
                                            'file_id': file_id})
        self.save()

    def log_metrics(self, metrics_by_name, info):
        """Store new measurements to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save there anything
            return
        for key in metrics_by_name:
            series = metrics_by_name[key]
            query = {"run_id": self.run_entry['_id'],
                     "name": key}
            push = {"steps": {"$each": series["steps"]},
                    "values": {"$each": series["values"]},
                    "timestamps": {"$each": series["timestamps"]}
                    }
            update = {"$push": push}
            result = self.metrics.update_one(query, update, upsert=True)
            if result.upserted_id is not None:
                # This is the first time we are storing this metric
                info.setdefault("metrics", []) \
                    .append({"name": key, "id": str(result.upserted_id)})

    def insert(self):
        """Insert ``self.run_entry`` into the runs collection.

        When overwriting, the existing document is replaced via save().
        When the entry has no ``_id``, one is generated as (largest existing
        ``_id`` + 1), or 1 for an empty collection; if another writer claims
        that id first, a new id is computed and the insert is retried.

        :raises ObserverError: if the entry cannot be encoded.
        :raises DuplicateKeyError: if an explicitly given ``_id`` is taken.
        """
        if self.overwrite:
            return self.save()

        autoinc_key = self.run_entry.get('_id') is None
        while True:
            if autoinc_key:
                last = self.runs.find_one({}, {'_id': 1},
                                          sort=[('_id', pymongo.DESCENDING)])
                self.run_entry['_id'] = 1 if last is None else last['_id'] + 1
            try:
                self.runs.insert_one(self.run_entry)
                # Only a successful insert ends the loop; a duplicate
                # auto-generated id falls through and is retried.
                return
            except InvalidDocument:
                raise ObserverError('Run contained an unserializable entry.'
                                    '(most likely in the info)')
            except DuplicateKeyError:
                if not autoinc_key:
                    raise

    def save(self):
        """Replace the stored document with the current run entry.

        A lost connection (AutoReconnect) is ignored on purpose: the next
        save will carry the same data.

        :raises ObserverError: if the entry cannot be encoded.
        """
        try:
            self.runs.replace_one({'_id': self.run_entry['_id']},
                                  self.run_entry)
        except AutoReconnect:
            pass  # just wait for the next save
        except InvalidDocument:
            raise ObserverError('Run contained an unserializable entry.'
                                '(most likely in the info)')

    def final_save(self, attempts):
        """Try up to ``attempts`` times to store the final run entry.

        After a connection loss the next attempt follows one second later.
        An entry that cannot be encoded is sanitised with
        force_bson_encodeable (with a warning on stderr) before the next
        attempt. If no attempt succeeds, the entry is pickled into a
        temporary file whose name is reported on stderr.
        """
        for i in range(attempts):
            try:
                # Upsert keeps the behaviour of the deprecated
                # Collection.save(): replace if present, insert otherwise.
                self.runs.replace_one({'_id': self.run_entry['_id']},
                                      self.run_entry, upsert=True)
                return
            except AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except InvalidDocument:
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment!"
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as f:
            pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

    def save_sources(self, ex_info):
        """Make sure every source file is in GridFS; return their references.

        :param ex_info: mapping with 'base_dir' and 'sources', the latter
                        being (relative name, md5) pairs.
        :return: list of ``[source_name, gridfs_id]`` entries. A file already
                 stored under the same absolute path and md5 is reused
                 instead of being uploaded again.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        for source_name, md5 in ex_info['sources']:
            abs_path = os.path.join(base_dir, source_name)
            stored = self.fs.find_one({'filename': abs_path, 'md5': md5})
            if stored:
                _id = stored._id
            else:
                with open(abs_path, 'rb') as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        """Two MongoObservers are equal when their runs collections are."""
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

    def __ne__(self, other):
        """Inverse of __eq__ (needed explicitly on Python 2)."""
        return not self.__eq__(other)
286
287
288
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    # NOTE(review): the class docstring above is kept verbatim because it is
    # presumably shown as the option's help text -- confirm before rewording.

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.collection][!priority]"

    # Building blocks for the accepted specifications. A '.' is allowed
    # inside DB_NAME_PATTERN because it separates db_name from collection.
    DB_NAME_PATTERN = r"[_A-Za-z][0-9A-Za-z#%&'()+\-;=@\[\]^_{}.]{0,63}"
    HOSTNAME_PATTERN = \
        r"(?=.{1,255}$)" \
        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?" \
        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*" \
        r"\.?"
    URL_PATTERN = "(?:" + HOSTNAME_PATTERN + ")" + ":" + "(?:[0-9]{1,5})"
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    PRIORITY_PATTERN = r"(?P<priority>!-?\d+)?"
    # The three accepted forms, each with an optional '!priority' suffix.
    DB_NAME = re.compile("^" + DB_NAME_PATTERN + PRIORITY_PATTERN + "$")
    URL = re.compile("^" + URL_PATTERN + PRIORITY_PATTERN + "$")
    URL_DB_NAME = re.compile("^(?P<url>" + URL_PATTERN + ")" + ":" +
                             "(?P<db_name>" + DB_NAME_PATTERN + ")" +
                             PRIORITY_PATTERN + "$")

    @classmethod
    def apply(cls, args, run):
        """Parse ``args`` and append a matching MongoObserver to the run.

        :param args: database specification string (see arg_description).
        :param run: object whose ``observers`` list receives the observer.
        :raises ValueError: if ``args`` has none of the accepted forms.
        """
        url, db_name, collection, priority = cls.parse_mongo_db_arg(args)
        if collection:
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         collection=collection,
                                         priority=priority)
        else:
            # No collection given: rely on the default of create().
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         priority=priority)

        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        """Split a database specification into its components.

        Accepted forms are ``db_name[.collection]``, ``host:port`` and
        ``host:port:db_name[.collection]``, each optionally followed by
        ``!priority``.

        :param mongo_db: the specification string.
        :return: tuple ``(url, db_name, collection, priority)``. ``url``
                 defaults to 'localhost:27017', ``db_name`` to 'sacred',
                 ``collection`` to '' and ``priority`` to
                 DEFAULT_MONGO_PRIORITY. The ``!priority`` suffix is never
                 part of the returned url, db_name or collection.
        :raises ValueError: if ``mongo_db`` matches none of the forms.
        """
        def split_priority(match):
            # Return (specification without the '!priority' suffix, priority).
            prio_str = match.group('priority')
            if prio_str is None:
                return mongo_db, DEFAULT_MONGO_PRIORITY
            # prio_str starts with '!', the rest is a signed integer.
            return mongo_db[:match.start('priority')], int(prio_str[1:])

        match = cls.DB_NAME.match(mongo_db)
        if match:
            spec, priority = split_priority(match)
            db_name, _, collection = spec.partition('.')
            return 'localhost:27017', db_name, collection, priority

        match = cls.URL.match(mongo_db)
        if match:
            url, priority = split_priority(match)
            return url, 'sacred', '', priority

        match = cls.URL_DB_NAME.match(mongo_db)
        if match:
            _, priority = split_priority(match)
            db_name, _, collection = match.group('db_name').partition('.')
            return match.group('url'), db_name, collection, priority

        raise ValueError('mongo_db argument must have the form "db_name" '
                         'or "host:port[:db_name]" but was {}'
                         .format(mongo_db))
346