Completed
Push — master ( 2d5fd3...e2acfa )
by Klaus
01:03
created

sacred.observers.PandasToJson   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 11
Duplicated Lines 0 %
Metric Value
dl 0
loc 11
rs 10
wmc 4

1 Method

Rating   Name   Duplication   Size   Complexity  
A transform_incoming() 0 8 4
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import os.path
6
import pickle
7
import re
8
import sys
9
import time
10
11
import bson
12
import gridfs
13
import pymongo
14
import sacred.optional as opt
15
from pymongo.errors import AutoReconnect, InvalidDocument
16
from pymongo.son_manipulator import SONManipulator
17
from sacred.commandline_options import CommandLineOption
18
from sacred.dependencies import get_digest
19
from sacred.observers.base import RunObserver
20
from sacred.utils import ObserverError
21
22
SON_MANIPULATORS = []
23
24
25
if opt.has_numpy:
    class NumpyArraysToList(SONManipulator):
        """Convert numpy arrays into nested lists before saving to MongoDB."""

        def transform_incoming(self, son, collection):
            """Replace every ndarray value in ``son``, recursively, by a list."""
            for field, entry in son.items():
                if isinstance(entry, opt.np.ndarray):
                    # ndarrays are not BSON-serializable; nested lists are.
                    son[field] = entry.tolist()
                elif isinstance(entry, dict):
                    # Descend into embedded sub-documents as well.
                    son[field] = self.transform_incoming(entry, collection)
            return son

    SON_MANIPULATORS.append(NumpyArraysToList())
39
40
41
if opt.has_pandas:
    pd = opt.pandas
    import json

    class PandasToJson(SONManipulator):
        """Convert pandas structures into plain dicts before saving to MongoDB."""

        def transform_incoming(self, son, collection):
            """Replace pandas values in ``son``, recursively, by their JSON form."""
            for field, entry in son.items():
                if isinstance(entry, (pd.Series, pd.DataFrame, pd.Panel)):
                    # Round-trip through JSON to obtain a BSON-friendly structure.
                    son[field] = json.loads(entry.to_json())
                elif isinstance(entry, dict):
                    # Descend into embedded sub-documents as well.
                    son[field] = self.transform_incoming(entry, collection)
            return son

    SON_MANIPULATORS.append(PandasToJson())
58
59
60
def force_valid_bson_key(key):
    """Return ``key`` as a string that MongoDB accepts as a document key.

    BSON forbids keys that start with ``$`` or contain ``.``: a leading
    ``$`` is replaced by ``@`` and every ``.`` becomes ``,``.
    """
    text = str(key)
    if text[:1] == '$':
        text = '@' + text[1:]
    return text.replace('.', ',')
66
67
68
def force_bson_encodeable(obj):
    """Return a version of ``obj`` that BSON can encode, coercing if needed.

    Dicts are test-encoded with key checking; on failure their keys are
    sanitised and their values fixed recursively.  Numpy arrays pass
    through untouched (a SON manipulator converts them on insert).  Any
    other value BSON rejects is replaced by its ``str()`` representation.
    """
    if isinstance(obj, dict):
        try:
            bson.BSON.encode(obj, check_keys=True)
        except bson.InvalidDocument:
            # Bad keys and/or bad values somewhere inside: repair both.
            return {force_valid_bson_key(k): force_bson_encodeable(v)
                    for k, v in obj.items()}
        return obj

    if opt.has_numpy and isinstance(obj, opt.np.ndarray):
        return obj

    try:
        # Wrap in a dict because bare values cannot be BSON-encoded directly.
        bson.BSON.encode({'dict_just_for_testing': obj})
    except bson.InvalidDocument:
        return str(obj)
    return obj
85
86
87
class MongoObserver(RunObserver):
    """Observer that stores all information about a run in a MongoDB.

    Run documents are written to the ``<prefix>.runs`` collection; source
    files, resources and artifacts are stored in GridFS under the same
    prefix.
    """

    @staticmethod
    def create(url='localhost', db_name='sacred', prefix='default', **kwargs):
        """Build a MongoObserver connected to ``db_name`` at ``url``.

        Extra keyword arguments are forwarded to ``pymongo.MongoClient``.
        """
        client = pymongo.MongoClient(url, **kwargs)
        database = client[db_name]
        # Install the converters (numpy/pandas -> JSON-able) on this database.
        for manipulator in SON_MANIPULATORS:
            database.add_son_manipulator(manipulator)
        runs_collection = database[prefix + '.runs']
        fs = gridfs.GridFS(database, collection=prefix)
        return MongoObserver(runs_collection, fs)

    def __init__(self, runs_collection, fs):
        self.runs = runs_collection  # collection holding the run documents
        self.fs = fs  # GridFS for sources, resources and artifacts
        self.run_entry = None  # document of the currently observed run

    def save(self):
        """Write the current run entry, tolerating transient disconnects.

        Raises ObserverError if the entry is not BSON-serializable.
        """
        try:
            self.runs.save(self.run_entry)
        except AutoReconnect:
            pass  # just wait for the next save
        except InvalidDocument:
            raise ObserverError('Run contained an unserializable entry.'
                                '(most likely in the info)')

    def final_save(self, attempts=10):
        """Persist the run entry at the end of a run.

        Retries up to ``attempts`` times on connection loss, coerces
        unserializable values on InvalidDocument, and finally falls back
        to pickling the entry to a temp file so nothing is lost.
        """
        for i in range(attempts):
            try:
                self.runs.save(self.run_entry)
                return
            except AutoReconnect:
                if i < attempts - 1:
                    time.sleep(1)
            except InvalidDocument:
                # Coerce offending values into something BSON accepts,
                # warn the user, and retry on the next iteration.
                self.run_entry = force_bson_encodeable(self.run_entry)
                print("Warning: Some of the entries of the run were not "
                      "BSON-serializable!\n They have been altered such that "
                      "they can be stored, but you should fix your experiment!"
                      "Most likely it is either the 'info' or the 'result'.",
                      file=sys.stderr)

        # Everything failed: keep the entry on disk as a last resort.
        from tempfile import NamedTemporaryFile
        with NamedTemporaryFile(suffix='.pickle', delete=False,
                                prefix='sacred_mongo_fail_') as f:
            pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

    def started_event(self, ex_info, host_info, start_time, config, comment):
        """Create the initial run document and upload missing source files."""
        self.run_entry = {
            'experiment': dict(ex_info),
            'host': dict(host_info),
            'start_time': start_time,
            'config': config,
            'comment': comment,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        }

        self.save()
        # Upload each source file to GridFS unless an identical copy
        # (same filename and md5) is already stored.
        for source_name, md5 in ex_info['sources']:
            if not self.fs.exists(filename=source_name, md5=md5):
                with open(source_name, 'rb') as f:
                    self.fs.put(f, filename=source_name)

    def heartbeat_event(self, info, captured_out, beat_time):
        """Update info, captured output and heartbeat time of the run."""
        self.run_entry['info'] = info
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run as COMPLETED and store its result."""
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'
        self.final_save(attempts=10)

    def interrupted_event(self, interrupt_time):
        """Mark the run as INTERRUPTED."""
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = 'INTERRUPTED'
        self.final_save(attempts=3)

    def failed_event(self, fail_time, fail_trace):
        """Mark the run as FAILED and store the traceback."""
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.final_save(attempts=1)

    def resource_event(self, filename):
        """Register ``filename`` as a resource, uploading it if necessary.

        If an identical copy (same name and md5) already exists in GridFS
        the file is not uploaded again.
        """
        if self.fs.exists(filename=filename):
            md5hash = get_digest(filename)
            if self.fs.exists(filename=filename, md5=md5hash):
                resource = (filename, md5hash)
                if resource not in self.run_entry['resources']:
                    self.run_entry['resources'].append(resource)
                    self.save()
                return
        with open(filename, 'rb') as f:
            file_id = self.fs.put(f, filename=filename)
        # Let GridFS compute the digest of what was actually stored.
        md5hash = self.fs.get(file_id).md5
        self.run_entry['resources'].append((filename, md5hash))
        self.save()

    def artifact_event(self, filename):
        """Store an artifact file in GridFS and link it to the run."""
        with open(filename, 'rb') as f:
            # Fix: os.path.split returned an unused 'head'; only the
            # basename is needed for the db filename.
            tail = os.path.basename(filename)
            run_id = self.run_entry['_id']
            db_filename = 'artifact://{}/{}/{}'.format(
                self.run_entry['experiment']['name'], run_id, tail)
            file_id = self.fs.put(f, filename=db_filename)
        self.run_entry['artifacts'].append(file_id)
        self.save()

    def __eq__(self, other):
        # Two observers are equal iff they write to the same collection.
        if isinstance(other, MongoObserver):
            return self.runs == other.runs
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
212
213
214
class MongoDbOption(CommandLineOption):
    """Add a MongoDB Observer to the experiment."""

    arg = 'DB'
    arg_description = "Database specification. Can be " \
                      "[host:port:]db_name[.prefix]"

    # A MongoDB database name (no '.', '$', '/', '\', '"' or space;
    # limited to 64 characters).
    DB_NAME_PATTERN = r"[_A-Za-z][0-9A-Za-z!#%&'()+\-;=@\[\]^_{}.]{0,63}"
    # Hostname per RFC 1035: dot-separated labels, at most 255 chars total.
    HOSTNAME_PATTERN = \
        r"(?=.{1,255}$)"\
        r"[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?"\
        r"(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*"\
        r"\.?"
    URL_PATTERN = "(?:" + HOSTNAME_PATTERN + ")" + ":" + "(?:[0-9]{1,5})"

    DB_NAME = re.compile("^" + DB_NAME_PATTERN + "$")
    URL = re.compile("^" + URL_PATTERN + "$")
    URL_DB_NAME = re.compile("^(?P<url>" + URL_PATTERN + ")" + ":" +
                             "(?P<db_name>" + DB_NAME_PATTERN + ")$")

    @classmethod
    def apply(cls, args, run):
        """Parse the DB specification and attach a MongoObserver to ``run``."""
        url, db_name, prefix = cls.parse_mongo_db_arg(args)
        if prefix:
            mongo = MongoObserver.create(db_name=db_name, url=url,
                                         prefix=prefix)
        else:
            # Omit prefix so MongoObserver.create uses its default.
            mongo = MongoObserver.create(db_name=db_name, url=url)

        run.observers.append(mongo)

    @classmethod
    def parse_mongo_db_arg(cls, mongo_db):
        """Split a DB specification into ``(url, db_name, prefix)``.

        Accepted forms are ``db_name[.prefix]``, ``host:port`` and
        ``host:port:db_name[.prefix]``; missing parts default to
        ``'localhost:27017'`` / ``'sacred'`` / ``''``.

        Raises ValueError if ``mongo_db`` matches none of the forms.
        """
        if cls.DB_NAME.match(mongo_db):
            db_name, _, prefix = mongo_db.partition('.')
            return 'localhost:27017', db_name, prefix
        elif cls.URL.match(mongo_db):
            return mongo_db, 'sacred', ''

        # Fix: match once instead of once in the condition and again in
        # the body.
        match = cls.URL_DB_NAME.match(mongo_db)
        if match:
            db_name, _, prefix = match.group('db_name').partition('.')
            return match.group('url'), db_name, prefix

        raise ValueError('mongo_db argument must have the form "db_name" '
                         'or "host:port[:db_name]" but was {}'
                         .format(mongo_db))
260