Completed: Push to master ( dbc38f...56accc ) by Klaus, created 01:34

NdArraySerializer    Rating: A

Complexity
    Total Complexity: 2

Size/Duplication
    Total Lines: 8
    Duplicated Lines: 0 %

Importance
    Changes: 3
    Bugs: 1    Features: 0

Metric                    Value
dl (duplicated lines)     0
loc (lines of code)       8
rs                        10
c (changes)               3
b (bugs)                  1
f (features)              0
wmc (total complexity)    2

2 Methods

Rating   Name       Duplication   Size   Complexity
A        decode()   0             2      1
A        encode()   0             2      1
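The two rated methods round-trip a NumPy array through a JSON string. A minimal
sketch of how the class (defined in the listing below) behaves, assuming numpy
is installed and using an arbitrary example array:

    import numpy as np

    s = NdArraySerializer()
    encoded = s.encode(np.arange(3))    # '[0, 1, 2]'
    restored = s.decode(encoded)        # array([0, 1, 2])

The full source of the inspected file follows.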
#!/usr/bin/env python
# coding=utf-8
from __future__ import (division, print_function, unicode_literals,
                        absolute_import)

import os
import datetime as dt
import json
import uuid
import textwrap
from collections import OrderedDict

from io import BufferedReader, FileIO

from sacred.__about__ import __version__
from sacred.observers import RunObserver
from sacred.commandline_options import CommandLineOption
import sacred.optional as opt

from tinydb import TinyDB, Query
from tinydb.queries import QueryImpl
from hashfs import HashFS
from tinydb_serialization import Serializer, SerializationMiddleware

__sacred__ = True  # marks files that should be filtered from stack traces

# Set data type values for abstract properties in Serializers
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
    """Custom wrapper to allow copying of a file handle.

    tinydb_serialization currently does a deepcopy on all the content of the
    dictionary before serialisation. By default, file handles are not
    copyable, so this wrapper is necessary to create a duplicate of the
    file handle passed in.

    Note that the file passed in will therefore remain open, as the copy is
    the one that gets closed.
    """

    def __init__(self, f_obj):
        f_obj = FileIO(f_obj.name)
        super(BufferedReaderWrapper, self).__init__(f_obj)

    def __copy__(self):
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)

    def __deepcopy__(self, memo):
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)


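# Illustrative sketch (editor's example, not part of the original file):
# deep-copying the wrapper re-opens the same file, so the copy can be closed
# while the original handle stays open. 'example.bin' is a placeholder path.
#
#     import copy
#     with open('example.bin', 'rb') as f:
#         wrapped = BufferedReaderWrapper(f)
#         duplicate = copy.deepcopy(wrapped)
#         assert duplicate.name == wrapped.name
#         duplicate.close()  # `wrapped` remains usable
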
class DateTimeSerializer(Serializer):
    OBJ_CLASS = dt.datetime  # The class this serializer handles

    def encode(self, obj):
        return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

    def decode(self, s):
        return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')


class NdArraySerializer(Serializer):
    OBJ_CLASS = ndarray_type

    def encode(self, obj):
        return json.dumps(obj.tolist(), check_circular=True)

    def decode(self, s):
        return opt.np.array(json.loads(s))


class DataFrameSerializer(Serializer):
    OBJ_CLASS = dataframe_type

    def encode(self, obj):
        return obj.to_json()

    def decode(self, s):
        return opt.pandas.read_json(s)


class SeriesSerializer(Serializer):
    OBJ_CLASS = series_type

    def encode(self, obj):
        return obj.to_json()

    def decode(self, s):
        return opt.pandas.read_json(s, typ='series')


class FileSerializer(Serializer):
    OBJ_CLASS = BufferedReaderWrapper

    def __init__(self, fs):
        self.fs = fs

    def encode(self, obj):
        address = self.fs.put(obj)
        return json.dumps(address.id)

    def decode(self, s):
        id_ = json.loads(s)
        file_reader = self.fs.open(id_)
        file_reader = BufferedReaderWrapper(file_reader)
        file_reader.hash = id_
        return file_reader


class TinyDbObserver(RunObserver):

    VERSION = "TinyDbObserver-{}".format(__version__)

    @staticmethod
    def create(path='./runs_db', overwrite=None):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Setup Serialisation object for non list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')

        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        return TinyDbObserver(db, fs, overwrite=overwrite, root=root_dir)

    def __init__(self, db, fs, overwrite=None, root=None):
        self.db = db
        self.runs = db.table('runs')
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = {}
        self.db_run_id = None
        self.root = root

    def save(self):
        """Insert or update the current run entry."""
        if self.db_run_id:
            self.runs.update(self.run_entry, eids=[self.db_run_id])
        else:
            db_run_id = self.runs.insert(self.run_entry)
            self.db_run_id = db_run_id

    def save_sources(self, ex_info):

        source_info = []
        for source_name, md5 in ex_info['sources']:

            # Substitute any HOME or Environment Vars to get absolute path
            abs_path = os.path.expanduser(source_name)
            abs_path = os.path.expandvars(abs_path)
            handle = BufferedReaderWrapper(open(abs_path, 'rb'))

            file = self.fs.get(md5)
            if file:
                id_ = file.id
            else:
                address = self.fs.put(abs_path)
                id_ = address.id
            source_info.append([source_name, id_, handle])
        return source_info

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        raise NotImplementedError('queued_event method is not implemented for'
                                  ' local TinyDbObserver.')

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):

        self.run_entry = {
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': config,
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        }

        # set ID if not given
        if _id is None:
            _id = uuid.uuid4().hex

        self.run_entry['_id'] = _id

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.save()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time):
        self.run_entry['info'] = info
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.save()

    def completed_event(self, stop_time, result):
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'
        self.save()

    def interrupted_event(self, interrupt_time, status):
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.save()

    def resource_event(self, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        resource = [filename, id_, handle]

        if resource not in self.run_entry['resources']:
            self.run_entry['resources'].append(resource)
            self.save()

    def artifact_event(self, name, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        artifact = [name, filename, id_, handle]

        if artifact not in self.run_entry['artifacts']:
            self.run_entry['artifacts'].append(artifact)
            self.save()

    def __eq__(self, other):
        if isinstance(other, TinyDbObserver):
            return self.runs.all() == other.runs.all()
        return False

    def __ne__(self, other):
        return not self.__eq__(other)


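# Usage sketch (editor's example, not part of the original file): attaching
# the observer to a Sacred experiment. The experiment name and database path
# below are placeholders.
#
#     from sacred import Experiment
#     ex = Experiment('my_experiment')
#     ex.observers.append(TinyDbObserver.create(path='./runs_db'))
#
#     @ex.automain
#     def run():
#         return 42
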
class TinyDbOption(CommandLineOption):
    """Add a TinyDB Observer to the experiment."""

    arg = 'BASEDIR'

    @classmethod
    def apply(cls, args, run):
        location = cls.parse_tinydb_arg(args)
        tinydb_obs = TinyDbObserver.create(path=location)
        run.observers.append(tinydb_obs)

    @classmethod
    def parse_tinydb_arg(cls, args):
        return args


class TinyDbReader(object):

    def __init__(self, path):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            raise IOError('Path does not exist: %s' % path)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Setup Serialisation for non list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')
        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        self.db = db
        self.runs = db.table('runs')
        self.fs = fs

    def search(self, *args, **kwargs):
        """Wrapper around TinyDB's search function."""
        return self.runs.search(*args, **kwargs)

    def fetch_files(self, exp_name=None, query=None, indices=None):
        """Return a dictionary of files for an experiment name or query.

        Returns a list with one dictionary per matched experiment. Each
        dictionary has the following structure

            {
              'exp_name': 'scascasc',
              'exp_id': 'dqwdqdqwf',
              'date': datetime_object,
              'sources': [ {'filename': filehandle}, ..., ],
              'resources': [ {'filename': filehandle}, ..., ],
              'artifacts': [ {'filename': filehandle}, ..., ]
            }

        """
        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       date=ent['start_time'])

            source_files = {x[0]: x[2] for x in ent['experiment']['sources']}
            resource_files = {x[0]: x[2] for x in ent['resources']}
            artifact_files = {x[0]: x[3] for x in ent['artifacts']}

            if source_files:
                rec['sources'] = source_files
            if resource_files:
                rec['resources'] = resource_files
            if artifact_files:
                rec['artifacts'] = artifact_files

            all_matched_entries.append(rec)

        return all_matched_entries

    def fetch_report(self, exp_name=None, query=None, indices=None):

        template = """
-------------------------------------------------
Experiment: {exp_name}
-------------------------------------------------
ID: {exp_id}
Date: {start_date}    Duration: {duration}

Parameters:
{parameters}

Result:
{result}

Dependencies:
{dependencies}

Resources:
{resources}

Source Files:
{sources}

Outputs:
{artifacts}
"""

        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            date = ent['start_time']
            weekdays = 'Mon Tue Wed Thu Fri Sat Sun'.split()
            w = weekdays[date.weekday()]
            date = ' '.join([w, date.strftime('%d %b %Y')])

            duration = ent['stop_time'] - ent['start_time']
            secs = duration.total_seconds()
            hours, remainder = divmod(secs, 3600)
            minutes, seconds = divmod(remainder, 60)
            duration = '%02d:%02d:%04.1f' % (hours, minutes, seconds)

            parameters = self._dict_to_indented_list(ent['config'])

            result = self._indent(ent['result'].__repr__(), prefix='    ')

            deps = ent['experiment']['dependencies']
            deps = self._indent('\n'.join(deps), prefix='    ')

            resources = [x[0] for x in ent['resources']]
            resources = self._indent('\n'.join(resources), prefix='    ')

            sources = [x[0] for x in ent['experiment']['sources']]
            sources = self._indent('\n'.join(sources), prefix='    ')

            artifacts = [x[0] for x in ent['artifacts']]
            artifacts = self._indent('\n'.join(artifacts), prefix='    ')

            none_str = '    None'

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       start_date=date,
                       duration=duration,
                       parameters=parameters if parameters else none_str,
                       result=result if result else none_str,
                       dependencies=deps if deps else none_str,
                       resources=resources if resources else none_str,
                       sources=sources if sources else none_str,
                       artifacts=artifacts if artifacts else none_str)

            report = template.format(**rec)

            all_matched_entries.append(report)

        return all_matched_entries

    def fetch_metadata(self, exp_name=None, query=None, indices=None):
        """Return all metadata for matching experiment name, index or query."""
        if exp_name or query:
            if query:
                assert isinstance(query, QueryImpl)
                q = query
            elif exp_name:
                q = Query().experiment.name.search(exp_name)

            entries = self.runs.search(q)

        elif indices or indices == 0:
            if not isinstance(indices, (tuple, list)):
                indices = [indices]

            num_recs = len(self.runs)

            for idx in indices:
                if idx >= num_recs:
                    raise ValueError(
                        'Index value ({}) must be less than '
                        'number of records ({})'.format(idx, num_recs))

            entries = [self.runs.all()[ind] for ind in indices]

        else:
            raise ValueError('Must specify an experiment name, indices or '
                             'pass a custom query')

        return entries

    def _dict_to_indented_list(self, d):

        d = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        output_str = ''

        for k, v in d.items():
            output_str += '%s: %s' % (k, v)
            output_str += '\n'

        output_str = self._indent(output_str.strip(), prefix='    ')

        return output_str

    def _indent(self, message, prefix):
        """Wrapper for indenting strings in Python 2 and 3."""
        preferred_width = 150
        wrapper = textwrap.TextWrapper(initial_indent=prefix,
                                       width=preferred_width,
                                       subsequent_indent=prefix)

        lines = message.splitlines()
        formatted_lines = [wrapper.fill(lin) for lin in lines]
        formatted_text = '\n'.join(formatted_lines)

        return formatted_text
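A minimal usage sketch for the reader side, assuming a run database was
previously written to './runs_db' by the observer above (the path and the
experiment name are placeholders):

    from tinydb import Query

    reader = TinyDbReader('./runs_db')

    # Formatted reports for runs whose experiment name matches
    reports = reader.fetch_report(exp_name='my_experiment')

    # Files (open handles) for the run at index 0
    files = reader.fetch_files(indices=0)

    # Or an arbitrary TinyDB query against the stored metadata
    completed = reader.search(Query().status == 'COMPLETED')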