TinyDbObserver.__init__()   A
last analyzed

Complexity:   Conditions 1
Size:         Total Lines 8
Duplication:  Lines 0, Ratio 0 %
Importance:   Changes 1, Bugs 0, Features 0

Metric  Value
cc      1
c       1
b       0
f       0
dl      0
loc     8
rs      10

#!/usr/bin/env python
# coding=utf-8
from __future__ import (division, print_function, unicode_literals,
                        absolute_import)

import os
import datetime as dt
import json
import uuid
import textwrap
from collections import OrderedDict

from io import BufferedReader, FileIO

from sacred.__about__ import __version__
from sacred.observers import RunObserver
from sacred.commandline_options import CommandLineOption
import sacred.optional as opt

# Set data type values for abstract properties in Serializers
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
27
    """Custom wrapper to allow for copying of file handle.
28
29
    tinydb_serialisation currently does a deepcopy on all the content of the
30
    dictionary before serialisation. By default, file handles are not
31
    copiable so this wrapper is necessary to create a duplicate of the
32
    file handle passes in.
33
34
    Note that the file passed in will therefor remain open as the copy is the
35
    one that gets closed.
36
    """
37
38
    def __init__(self, f_obj):
39
        f_obj = FileIO(f_obj.name)
40
        super(BufferedReaderWrapper, self).__init__(f_obj)
41
42
    def __copy__(self):
43
        f = open(self.name, self.mode)
44
        return BufferedReaderWrapper(f)
45
46
    def __deepcopy__(self, memo):
47
        f = open(self.name, self.mode)
48
        return BufferedReaderWrapper(f)
49
50
51
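# Illustrative sketch (added note, not part of the original module): because
# __copy__ and __deepcopy__ reopen the file by name, copying a wrapper yields
# an independent handle onto the same file, which is what allows the deepcopy
# inside tinydb_serialization to succeed. For example (assuming a local file
# 'example.txt' exists):
#
#     import copy
#     handle = BufferedReaderWrapper(open('example.txt', 'rb'))
#     clone = copy.deepcopy(handle)   # fresh handle, positioned at the start
#     assert clone.read() == open('example.txt', 'rb').read()

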
class TinyDbObserver(RunObserver):

    VERSION = "TinyDbObserver-{}".format(__version__)

    @staticmethod
    def create(path='./runs_db', overwrite=None):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Set up the serialisation object for non-list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')

        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        return TinyDbObserver(db, fs, overwrite=overwrite, root=root_dir)

    def __init__(self, db, fs, overwrite=None, root=None):
        self.db = db
        self.runs = db.table('runs')
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = {}
        self.db_run_id = None
        self.root = root

    def save(self):
        """Insert or update the current run entry."""
        if self.db_run_id:
            self.runs.update(self.run_entry, eids=[self.db_run_id])
        else:
            db_run_id = self.runs.insert(self.run_entry)
            self.db_run_id = db_run_id

    def save_sources(self, ex_info):

        source_info = []
        for source_name, md5 in ex_info['sources']:

            # Substitute any HOME or Environment Vars to get absolute path
            abs_path = os.path.join(ex_info['base_dir'], source_name)
            abs_path = os.path.expanduser(abs_path)
            abs_path = os.path.expandvars(abs_path)
            handle = BufferedReaderWrapper(open(abs_path, 'rb'))

            file = self.fs.get(md5)
            if file:
                id_ = file.id
            else:
                address = self.fs.put(abs_path)
                id_ = address.id
            source_info.append([source_name, id_, handle])
        return source_info

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        raise NotImplementedError('queued_event method is not implemented for'
                                  ' local TinyDbObserver.')

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):

        self.run_entry = {
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': config,
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        }

        # set ID if not given
        if _id is None:
            _id = uuid.uuid4().hex

        self.run_entry['_id'] = _id

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.save()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.run_entry['info'] = info
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.run_entry['result'] = result
        self.save()

    def completed_event(self, stop_time, result):
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'
        self.save()

    def interrupted_event(self, interrupt_time, status):
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.save()

    def resource_event(self, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        resource = [filename, id_, handle]

        if resource not in self.run_entry['resources']:
            self.run_entry['resources'].append(resource)
            self.save()

    def artifact_event(self, name, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        artifact = [name, filename, id_, handle]

        if artifact not in self.run_entry['artifacts']:
            self.run_entry['artifacts'].append(artifact)
            self.save()

    def __eq__(self, other):
        if isinstance(other, TinyDbObserver):
            return self.runs.all() == other.runs.all()
        return False

    def __ne__(self, other):
        return not self.__eq__(other)


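# Illustrative usage sketch (added note, not part of the original module):
# with the optional tinydb, hashfs and tinydb_serialization packages
# installed, the observer is attached to an experiment like any other Sacred
# observer. The experiment name below is a placeholder:
#
#     from sacred import Experiment
#
#     ex = Experiment('my_experiment')
#     ex.observers.append(TinyDbObserver.create(path='./runs_db'))
#
#     @ex.automain
#     def main():
#         return 42
#
# Run metadata then lands in ./runs_db/metadata.json, while source files,
# resources and artifacts are stored content-addressably under
# ./runs_db/hashfs.

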
class TinyDbOption(CommandLineOption):
    """Add a TinyDB Observer to the experiment."""

    __depends_on__ = ['tinydb', 'hashfs',
                      'tinydb_serialization#tinydb-serialization']

    arg = 'BASEDIR'

    @classmethod
    def apply(cls, args, run):
        location = cls.parse_tinydb_arg(args)
        tinydb_obs = TinyDbObserver.create(path=location)
        run.observers.append(tinydb_obs)

    @classmethod
    def parse_tinydb_arg(cls, args):
        return args


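# Added note (not part of the original module): CommandLineOption subclasses
# are exposed by Sacred as command-line flags derived from the class name, so
# this option should surface roughly as a --tiny_db BASEDIR flag once the
# dependencies listed in __depends_on__ are importable. The exact flag name
# and the script name below are assumptions based on Sacred's conventions:
#
#     python my_experiment.py --tiny_db ./runs_db

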
class TinyDbReader(object):

    def __init__(self, path):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            raise IOError('Path does not exist: %s' % path)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Set up serialisation for non-list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')
        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        self.db = db
        self.runs = db.table('runs')
        self.fs = fs

    def search(self, *args, **kwargs):
        """Wrapper around TinyDB's search function."""
        return self.runs.search(*args, **kwargs)

    def fetch_files(self, exp_name=None, query=None, indices=None):
        """Return the files recorded for a matching experiment name or query.

        Returns a list with one dictionary per matched experiment. Each
        dictionary has the following structure

            {
              'exp_name': '<experiment name>',
              'exp_id': '<run id>',
              'date': <datetime object>,
              'sources': {'<filename>': <file handle>, ...},
              'resources': {'<filename>': <file handle>, ...},
              'artifacts': {'<filename>': <file handle>, ...}
            }

        """
        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       date=ent['start_time'])

            source_files = {x[0]: x[2] for x in ent['experiment']['sources']}
            resource_files = {x[0]: x[2] for x in ent['resources']}
            artifact_files = {x[0]: x[3] for x in ent['artifacts']}

            if source_files:
                rec['sources'] = source_files
            if resource_files:
                rec['resources'] = resource_files
            if artifact_files:
                rec['artifacts'] = artifact_files

            all_matched_entries.append(rec)

        return all_matched_entries

    def fetch_report(self, exp_name=None, query=None, indices=None):

        template = """
-------------------------------------------------
Experiment: {exp_name}
-------------------------------------------------
ID: {exp_id}
Date: {start_date}    Duration: {duration}

Parameters:
{parameters}

Result:
{result}

Dependencies:
{dependencies}

Resources:
{resources}

Source Files:
{sources}

Outputs:
{artifacts}
"""

        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            date = ent['start_time']
            weekdays = 'Mon Tue Wed Thu Fri Sat Sun'.split()
            w = weekdays[date.weekday()]
            date = ' '.join([w, date.strftime('%d %b %Y')])

            duration = ent['stop_time'] - ent['start_time']
            secs = duration.total_seconds()
            hours, remainder = divmod(secs, 3600)
            minutes, seconds = divmod(remainder, 60)
            duration = '%02d:%02d:%04.1f' % (hours, minutes, seconds)

            parameters = self._dict_to_indented_list(ent['config'])

            result = self._indent(ent['result'].__repr__(), prefix='    ')

            deps = ent['experiment']['dependencies']
            deps = self._indent('\n'.join(deps), prefix='    ')

            resources = [x[0] for x in ent['resources']]
            resources = self._indent('\n'.join(resources), prefix='    ')

            sources = [x[0] for x in ent['experiment']['sources']]
            sources = self._indent('\n'.join(sources), prefix='    ')

            artifacts = [x[0] for x in ent['artifacts']]
            artifacts = self._indent('\n'.join(artifacts), prefix='    ')

            none_str = '    None'

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       start_date=date,
                       duration=duration,
                       parameters=parameters if parameters else none_str,
                       result=result if result else none_str,
                       dependencies=deps if deps else none_str,
                       resources=resources if resources else none_str,
                       sources=sources if sources else none_str,
                       artifacts=artifacts if artifacts else none_str)

            report = template.format(**rec)

            all_matched_entries.append(report)

        return all_matched_entries

    def fetch_metadata(self, exp_name=None, query=None, indices=None):
        """Return all metadata for matching experiment name, index or query."""
        if exp_name or query:
            if query:
                assert isinstance(query, QueryImpl)
                q = query
            elif exp_name:
                q = Query().experiment.name.search(exp_name)

            entries = self.runs.search(q)

        elif indices or indices == 0:
            if not isinstance(indices, (tuple, list)):
                indices = [indices]

            num_recs = len(self.runs)

            for idx in indices:
                if idx >= num_recs:
                    raise ValueError(
                        'Index value ({}) must be less than '
                        'number of records ({})'.format(idx, num_recs))

            entries = [self.runs.all()[ind] for ind in indices]

        else:
            raise ValueError('Must specify an experiment name, indices or '
                             'pass a custom query')

        return entries

    def _dict_to_indented_list(self, d):

        d = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        output_str = ''

        for k, v in d.items():
            output_str += '%s: %s' % (k, v)
            output_str += '\n'

        output_str = self._indent(output_str.strip(), prefix='    ')

        return output_str

    def _indent(self, message, prefix):
        """Wrapper for indenting strings in Python 2 and 3."""
        preferred_width = 150
        wrapper = textwrap.TextWrapper(initial_indent=prefix,
                                       width=preferred_width,
                                       subsequent_indent=prefix)

        lines = message.splitlines()
        formatted_lines = [wrapper.fill(lin) for lin in lines]
        formatted_text = '\n'.join(formatted_lines)

        return formatted_text


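# Illustrative query sketch (added note, not part of the original module):
# once runs have been recorded, reports and raw metadata can be pulled either
# by experiment name or with a tinydb Query (available only when the optional
# tinydb dependency is installed; see the import block below). The experiment
# name is a placeholder:
#
#     reader = TinyDbReader('./runs_db')
#     print(reader.fetch_report(exp_name='my_experiment')[0])
#     completed = reader.fetch_metadata(query=Query().status == 'COMPLETED')

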
if opt.has_tinydb:  # noqa
    from tinydb import TinyDB, Query
    from tinydb.queries import QueryImpl
    from hashfs import HashFS
    from tinydb_serialization import Serializer, SerializationMiddleware

    class DateTimeSerializer(Serializer):
        OBJ_CLASS = dt.datetime  # The class this serializer handles

        def encode(self, obj):
            return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

        def decode(self, s):
            return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')

    class NdArraySerializer(Serializer):
        OBJ_CLASS = ndarray_type

        def encode(self, obj):
            return json.dumps(obj.tolist(), check_circular=True)

        def decode(self, s):
            return opt.np.array(json.loads(s))

    class DataFrameSerializer(Serializer):
        OBJ_CLASS = dataframe_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s)

    class SeriesSerializer(Serializer):
        OBJ_CLASS = series_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s, typ='series')

    class FileSerializer(Serializer):
        OBJ_CLASS = BufferedReaderWrapper

        def __init__(self, fs):
            self.fs = fs

        def encode(self, obj):
            address = self.fs.put(obj)
            return json.dumps(address.id)

        def decode(self, s):
            id_ = json.loads(s)
            file_reader = self.fs.open(id_)
            file_reader = BufferedReaderWrapper(file_reader)
            file_reader.hash = id_
            return file_reader
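
A minimal read-back sketch to close the loop, assuming runs were previously
recorded under ./runs_db by the observer above (the experiment name is a
placeholder; TinyDbReader is the class defined in this module): fetch_files
returns one record per matched run, mapping stored filenames to reopened
file handles backed by the hashfs store.

    reader = TinyDbReader('./runs_db')

    records = reader.fetch_files(exp_name='my_experiment')
    for rec in records:
        print(rec['exp_name'], rec['exp_id'], rec['date'])
        # 'sources' is only present when the run has stored source files
        for filename, handle in rec.get('sources', {}).items():
            print('  source:', filename, len(handle.read()), 'bytes')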