Completed
Push — master ( dbc38f...56accc )
by Klaus
01:34
created

PandasToJson   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 11
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
c 0
b 0
f 0
dl 0
loc 11
rs 10
wmc 4
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import pickle
6
import re
7
import os.path
8
import sys
9
import time
10
11
import bson
12
import gridfs
13
import pymongo
14
import sacred.optional as opt
15
from pymongo.errors import AutoReconnect, InvalidDocument, DuplicateKeyError
16
from sacred.commandline_options import CommandLineOption
17
from sacred.dependencies import get_digest
18
from sacred.observers.base import RunObserver
19
from sacred.serializer import flatten
20
from sacred.utils import ObserverError
21
22
23
def force_valid_bson_key(key):
    """Return *key* as a string that MongoDB accepts as a document key.

    BSON document keys must not start with '$' or contain '.': a leading
    '$' is swapped for '@' and every '.' becomes ','.
    """
    sanitized = str(key)
    if sanitized[:1] == '$':
        sanitized = '@' + sanitized[1:]
    return sanitized.replace('.', ',')
29
30
31
def force_bson_encodeable(obj):
    """Recursively coerce *obj* into something BSON can encode.

    Dicts that fail a test-encode get their keys sanitized and their
    values coerced recursively; numpy arrays pass through untouched;
    any other value BSON rejects is replaced by its string
    representation.
    """
    if isinstance(obj, dict):
        try:
            # probe-encode with key checking; success means no work needed
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            return {force_valid_bson_key(key): force_bson_encodeable(value)
                    for key, value in obj.items()}
        return obj
    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        return obj
    try:
        # wrap in a throwaway dict because BSON can only encode documents
        bson.BSON.encode({'dict_just_for_testing': obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
48
49
50
class MongoObserver(RunObserver):
    """Observer that records experiment runs in a MongoDB database.

    Run documents are written to a configurable collection; source
    files, resources and artifacts are stored in GridFS alongside it.
    """

    # Collection names reserved for GridFS / internal bookkeeping.
    # NOTE(review): 'seach_space' looks like a typo of 'search_space' —
    # renaming it would change which names get rejected, so it is kept
    # as-is; confirm intent before fixing.
    COLLECTION_NAME_BLACKLIST = {'fs.files', 'fs.chunks', '_properties',
                                 'system.indexes', 'seach_space'}
    VERSION = 'MongoObserver-0.7.0'

    @staticmethod
    def create(url='localhost', db_name='sacred', collection='runs',
               overwrite=None, **kwargs):
        """Build a MongoObserver connected to the given MongoDB.

        :param url: server address handed to ``pymongo.MongoClient``.
        :param db_name: database in which runs are stored.
        :param collection: collection for run documents; must not be a
            reserved name from ``COLLECTION_NAME_BLACKLIST``.
        :param overwrite: an existing run entry to overwrite, or None.
        :param kwargs: extra keyword args forwarded to ``MongoClient``.
        :raises KeyError: if *collection* is a reserved name.
        """
        client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        if collection in MongoObserver.COLLECTION_NAME_BLACKLIST:
            raise KeyError('Collection name "{}" is reserved. '
                           'Please use a different one.'.format(collection))
        runs_collection = database[collection]
        fs = gridfs.GridFS(database)
        return MongoObserver(runs_collection, fs, overwrite=overwrite)

    def __init__(self, runs_collection, fs, overwrite=None):
        self.runs = runs_collection     # pymongo collection of run docs
        self.fs = fs                    # GridFS for sources/artifacts
        self.overwrite = overwrite      # run entry to overwrite, or None
        self.run_entry = None           # the document being maintained

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Create a QUEUED run entry and return its database id."""
        if self.overwrite is not None:
            raise RuntimeError("Can't overwrite with QUEUED run.")
        self.run_entry = {
            'experiment': dict(ex_info),
            'command': command,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'QUEUED'
        }
        # set ID if given
        if _id is not None:
            self.run_entry['_id'] = _id
        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.final_save(attempts=1)
        return self.run_entry['_id']

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Create or overwrite a RUNNING run entry; return its id."""
        if self.overwrite is None:
            self.run_entry = {'_id': _id}
        else:
            if self.run_entry is not None:
                raise RuntimeError("Cannot overwrite more than once!")
            # sanity check: refuse to overwrite a run of different code
            if self.overwrite['experiment']['sources'] != ex_info['sources']:
                raise RuntimeError("Sources don't match")
            self.run_entry = self.overwrite

        self.run_entry.update({
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': flatten(config),
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        })

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.insert()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time):
        """Refresh info, captured output and heartbeat timestamp."""
        self.run_entry['info'] = flatten(info)
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run COMPLETED and persist it (retrying hard)."""
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time, status):
        """Record an interruption with the given status."""
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run FAILED and store the traceback."""
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Register *filename* as a resource, uploading it if needed.

        If a GridFS file with the same name and md5 already exists, only
        the (filename, md5) pair is recorded; otherwise the file is
        uploaded first.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        # not stored yet (or content changed): upload and record it
        with open(filename, 'rb') as f:
            file_id = self.fs.put(f, filename=filename)
        md5hash = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, md5hash))
        self.save()

    def artifact_event(self, name, filename):
        """Upload *filename* to GridFS and attach it as artifact *name*."""
        with open(filename, 'rb') as f:
            run_id = self.run_entry['_id']
            db_filename = 'artifact://{}/{}/{}'.format(self.runs.name, run_id,
                                                       name)
            file_id = self.fs.put(f, filename=db_filename)
        self.run_entry['artifacts'].append({'name': name,
                                            'file_id': file_id})
        self.save()

    def insert(self):
        """Insert the current run entry, auto-incrementing its id.

        When the id is auto-generated, a DuplicateKeyError (another
        process grabbed the same id) triggers a retry with a fresh id;
        for explicit ids it is re-raised.
        :raises ObserverError: if the entry is not BSON-serializable.
        """
        if self.overwrite:
            return self.save()

        autoinc_key = self.run_entry['_id'] is None
        while True:
            if autoinc_key:
                # highest existing id + 1; find_one with a sort spec
                # replaces the broken cursor.sort({'_id': -1}) call
                # (sort takes a key or list of pairs, not a dict) and
                # avoids deprecated cursor.count()/next()
                latest = self.runs.find_one(
                    {}, {'_id': 1}, sort=[('_id', pymongo.DESCENDING)])
                self.run_entry['_id'] = latest['_id'] + 1 if latest else 1
            try:
                self.runs.insert_one(self.run_entry)
                # return must live inside the try: previously it sat
                # after the except-block, so a DuplicateKeyError with an
                # autoinc id returned WITHOUT ever inserting the run
                return
            except InvalidDocument:
                raise ObserverError('Run contained an unserializable entry. '
                                    '(most likely in the info)')
            except DuplicateKeyError:
                if not autoinc_key:
                    raise
                # id collision with a concurrent insert: loop and retry

    def save(self):
        """Best-effort write of the current entry.

        AutoReconnect is swallowed (the next save will catch up);
        unserializable entries raise ObserverError.
        """
        try:
            self.runs.replace_one({'_id': self.run_entry['_id']},
                                  self.run_entry)
        except AutoReconnect:
            pass  # just wait for the next save
        except InvalidDocument:
            raise ObserverError('Run contained an unserializable entry. '
                                '(most likely in the info)')

    def final_save(self, attempts):
        """Persist the entry with up to *attempts* tries.

        Sleeps and retries on AutoReconnect, force-sanitizes the entry
        on serialization errors, and pickles it to a temp file as a
        last resort so no run data is lost.
        """
        for attempt in range(attempts):
            try:
                # upsert/insert replicates the deprecated
                # Collection.save() used before, consistently with the
                # pymongo-3 CRUD API used elsewhere in this class;
                # insert_one injects the generated '_id' into run_entry,
                # which queued_event relies on
                if '_id' in self.run_entry:
                    self.runs.replace_one({'_id': self.run_entry['_id']},
                                          self.run_entry, upsert=True)
                else:
                    self.runs.insert_one(self.run_entry)
                return
            except AutoReconnect:
                if attempt < attempts - 1:
                    time.sleep(1)
            except InvalidDocument:
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment!"
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as f:
            pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

    def save_sources(self, ex_info):
        """Upload the experiment's source files to GridFS (deduplicated).

        :return: list of [source_name, gridfs_id] pairs.
        """
        base_dir = ex_info['base_dir']
        source_info = []
        for source_name, md5 in ex_info['sources']:
            abs_path = os.path.join(base_dir, source_name)
            file = self.fs.find_one({'filename': abs_path, 'md5': md5})
            if file:
                _id = file._id  # identical content already stored
            else:
                with open(abs_path, 'rb') as f:
                    _id = self.fs.put(f, filename=abs_path)
            source_info.append([source_name, _id])
        return source_info

    def __eq__(self, other):
        # two observers are equal iff they write to the same collection
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
247
248
249
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.collection]"

    # legal MongoDB database-name characters, max 64 chars
    DB_NAME_PATTERN = r"[_A-Za-z][0-9A-Za-z!#%&'()+\-;=@\[\]^_{}.]{0,63}"
    # RFC-1035-style hostname, total length capped at 255
    HOSTNAME_PATTERN = \
        r"(?=.{1,255}$)"\
        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"\
        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*"\
        r"\.?"
    URL_PATTERN = "(?:" + HOSTNAME_PATTERN + ")" + ":" + "(?:[0-9]{1,5})"

    DB_NAME = re.compile("^" + DB_NAME_PATTERN + "$")
    URL = re.compile("^" + URL_PATTERN + "$")
    URL_DB_NAME = re.compile("^(?P<url>" + URL_PATTERN + ")" + ":" +
                             "(?P<db_name>" + DB_NAME_PATTERN + ")$")

    @classmethod
    def apply(cls, args, run):
        """Parse *args* and attach a matching MongoObserver to *run*."""
        url, db_name, collection = cls.parse_mongo_db_arg(args)
        if collection:
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         collection=collection)
        else:
            mongo = MongoObserver.create(db_name=db_name, url=url)

        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        """Split '[host:port:]db_name[.collection]' into its parts.

        :return: (url, db_name, collection) with defaults
            'localhost:27017' / 'sacred' / '' filled in.
        :raises ValueError: if *mongo_db* matches none of the forms.
        """
        if cls.DB_NAME.match(mongo_db):
            db_name, _, collection = mongo_db.partition('.')
            return 'localhost:27017', db_name, collection
        if cls.URL.match(mongo_db):
            return mongo_db, 'sacred', ''
        # run the regex once instead of twice (the original matched
        # URL_DB_NAME in the elif condition and again in the body)
        match = cls.URL_DB_NAME.match(mongo_db)
        if match:
            db_name, _, collection = match.group('db_name').partition('.')
            return match.group('url'), db_name, collection
        raise ValueError('mongo_db argument must have the form "db_name" '
                         'or "host:port[:db_name]" but was {}'
                         .format(mongo_db))
295