#!/usr/bin/env python
# coding=utf-8
from __future__ import (division, print_function, unicode_literals,
                        absolute_import)

import os
import datetime as dt
import json
import uuid
import textwrap
from collections import OrderedDict

from io import BufferedReader, FileIO

from sacred.__about__ import __version__
from sacred.observers import RunObserver
from sacred.commandline_options import CommandLineOption
import sacred.optional as opt

# Set data type values for abstract properties in Serializers
series_type = opt.pandas.Series if opt.has_pandas else None
dataframe_type = opt.pandas.DataFrame if opt.has_pandas else None
ndarray_type = opt.np.ndarray if opt.has_numpy else None


class BufferedReaderWrapper(BufferedReader):
    """Custom wrapper to allow for copying of file handles.

    tinydb_serialization currently does a deepcopy on all the content of the
    dictionary before serialisation. By default, file handles are not
    copyable, so this wrapper is necessary to create a duplicate of the
    file handle passed in.

    Note that the file passed in will therefore remain open, as the copy is
    the one that gets closed.
    """

    def __init__(self, f_obj):
        f_obj = FileIO(f_obj.name)
        super(BufferedReaderWrapper, self).__init__(f_obj)

    def __copy__(self):
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)

    def __deepcopy__(self, memo):
        f = open(self.name, self.mode)
        return BufferedReaderWrapper(f)

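
# A minimal usage sketch (illustrative only, not executed at import time;
# the file path is hypothetical): copying the wrapper reopens the underlying
# file by name, so the copy can be closed independently while the original
# handle stays open.
#
#     import copy
#     original = BufferedReaderWrapper(open('some_file.txt', 'rb'))
#     duplicate = copy.deepcopy(original)
#     duplicate.close()
#     assert not original.closed
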
class TinyDbObserver(RunObserver):

    VERSION = "TinyDbObserver-{}".format(__version__)

    @staticmethod
    def create(path='./runs_db', overwrite=None):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            os.makedirs(root_dir)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Setup Serialisation object for non list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')

        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        return TinyDbObserver(db, fs, overwrite=overwrite, root=root_dir)

    def __init__(self, db, fs, overwrite=None, root=None):
        self.db = db
        self.runs = db.table('runs')
        self.fs = fs
        self.overwrite = overwrite
        self.run_entry = {}
        self.db_run_id = None
        self.root = root

    def save(self):
        """Insert or update the current run entry."""
        if self.db_run_id:
            self.runs.update(self.run_entry, eids=[self.db_run_id])
        else:
            db_run_id = self.runs.insert(self.run_entry)
            self.db_run_id = db_run_id

    def save_sources(self, ex_info):

        source_info = []
        for source_name, md5 in ex_info['sources']:

            # Substitute any HOME or Environment Vars to get absolute path
            abs_path = os.path.join(ex_info['base_dir'], source_name)
            abs_path = os.path.expanduser(abs_path)
            abs_path = os.path.expandvars(abs_path)
            handle = BufferedReaderWrapper(open(abs_path, 'rb'))

            file = self.fs.get(md5)
            if file:
                id_ = file.id
            else:
                address = self.fs.put(abs_path)
                id_ = address.id
            source_info.append([source_name, id_, handle])
        return source_info

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        raise NotImplementedError('queued_event method is not implemented for'
                                  ' local TinyDbObserver.')

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):

        self.run_entry = {
            'experiment': dict(ex_info),
            'format': self.VERSION,
            'command': command,
            'host': dict(host_info),
            'start_time': start_time,
            'config': config,
            'meta': meta_info,
            'status': 'RUNNING',
            'resources': [],
            'artifacts': [],
            'captured_out': '',
            'info': {},
            'heartbeat': None
        }

        # set ID if not given
        if _id is None:
            _id = uuid.uuid4().hex

        self.run_entry['_id'] = _id

        # save sources
        self.run_entry['experiment']['sources'] = self.save_sources(ex_info)
        self.save()
        return self.run_entry['_id']

    def heartbeat_event(self, info, captured_out, beat_time, result):
        self.run_entry['info'] = info
        self.run_entry['captured_out'] = captured_out
        self.run_entry['heartbeat'] = beat_time
        self.run_entry['result'] = result
        self.save()

    def completed_event(self, stop_time, result):
        self.run_entry['stop_time'] = stop_time
        self.run_entry['result'] = result
        self.run_entry['status'] = 'COMPLETED'
        self.save()

    def interrupted_event(self, interrupt_time, status):
        self.run_entry['stop_time'] = interrupt_time
        self.run_entry['status'] = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        self.run_entry['stop_time'] = fail_time
        self.run_entry['status'] = 'FAILED'
        self.run_entry['fail_trace'] = fail_trace
        self.save()

    def resource_event(self, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        resource = [filename, id_, handle]

        if resource not in self.run_entry['resources']:
            self.run_entry['resources'].append(resource)
            self.save()

    def artifact_event(self, name, filename):

        id_ = self.fs.put(filename).id
        handle = BufferedReaderWrapper(open(filename, 'rb'))
        artifact = [name, filename, id_, handle]

        if artifact not in self.run_entry['artifacts']:
            self.run_entry['artifacts'].append(artifact)
            self.save()

    def __eq__(self, other):
        if isinstance(other, TinyDbObserver):
            return self.runs.all() == other.runs.all()
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

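
# A minimal sketch of attaching this observer to an experiment (assumes a
# Sacred Experiment named 'my_experiment' defined in the calling script; the
# path 'my_runs_db' is hypothetical). Run metadata is written to
# <basedir>/metadata.json and file content to the <basedir>/hashfs store.
#
#     from sacred import Experiment
#
#     ex = Experiment('my_experiment')
#     ex.observers.append(TinyDbObserver.create('my_runs_db'))
#
#     @ex.automain
#     def run_experiment():
#         return 42
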
class TinyDbOption(CommandLineOption):
    """Add a TinyDB Observer to the experiment."""

    __depends_on__ = ['tinydb', 'hashfs',
                      'tinydb_serialization#tinydb-serialization']

    arg = 'BASEDIR'

    @classmethod
    def apply(cls, args, run):
        location = cls.parse_tinydb_arg(args)
        tinydb_obs = TinyDbObserver.create(path=location)
        run.observers.append(tinydb_obs)

    @classmethod
    def parse_tinydb_arg(cls, args):
        return args

class TinyDbReader(object):

    def __init__(self, path):

        root_dir = os.path.abspath(path)
        if not os.path.exists(root_dir):
            raise IOError('Path does not exist: %s' % path)

        fs = HashFS(os.path.join(root_dir, 'hashfs'), depth=3,
                    width=2, algorithm='md5')

        # Setup Serialisation for non list/dict objects
        serialization_store = SerializationMiddleware()
        serialization_store.register_serializer(DateTimeSerializer(),
                                                'TinyDate')
        serialization_store.register_serializer(FileSerializer(fs),
                                                'TinyFile')
        if opt.has_numpy:
            serialization_store.register_serializer(NdArraySerializer(),
                                                    'TinyArray')
        if opt.has_pandas:
            serialization_store.register_serializer(DataFrameSerializer(),
                                                    'TinyDataFrame')
            serialization_store.register_serializer(SeriesSerializer(),
                                                    'TinySeries')

        db = TinyDB(os.path.join(root_dir, 'metadata.json'),
                    storage=serialization_store)

        self.db = db
        self.runs = db.table('runs')
        self.fs = fs

    def search(self, *args, **kwargs):
        """Wrapper around TinyDB's search function."""
        return self.runs.search(*args, **kwargs)

    def fetch_files(self, exp_name=None, query=None, indices=None):
        """Return a dictionary of files for an experiment name or query.

        Returns a list of one dictionary per matched experiment. Each
        dictionary has the following structure:

            {
              'exp_name': 'scascasc',
              'exp_id': 'dqwdqdqwf',
              'date': datetime_object,
              'sources': {'filename': filehandle, ...},
              'resources': {'filename': filehandle, ...},
              'artifacts': {'filename': filehandle, ...}
            }

        """
        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       date=ent['start_time'])

            source_files = {x[0]: x[2] for x in ent['experiment']['sources']}
            resource_files = {x[0]: x[2] for x in ent['resources']}
            artifact_files = {x[0]: x[3] for x in ent['artifacts']}

            if source_files:
                rec['sources'] = source_files
            if resource_files:
                rec['resources'] = resource_files
            if artifact_files:
                rec['artifacts'] = artifact_files

            all_matched_entries.append(rec)

        return all_matched_entries

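    # A minimal sketch of reading stored files back (assumes runs were
    # previously recorded under the hypothetical path 'my_runs_db' by a
    # TinyDbObserver):
    #
    #     reader = TinyDbReader('my_runs_db')
    #     results = reader.fetch_files(exp_name='my_experiment')
    #     for filename, filehandle in results[0]['sources'].items():
    #         print(filename, len(filehandle.read()))
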
    def fetch_report(self, exp_name=None, query=None, indices=None):

        template = """
-------------------------------------------------
Experiment: {exp_name}
-------------------------------------------------
ID: {exp_id}
Date: {start_date}          Duration: {duration}

Parameters:
{parameters}

Result:
{result}

Dependencies:
{dependencies}

Resources:
{resources}

Source Files:
{sources}

Outputs:
{artifacts}
"""

        entries = self.fetch_metadata(exp_name, query, indices)

        all_matched_entries = []
        for ent in entries:

            date = ent['start_time']
            weekdays = 'Mon Tue Wed Thu Fri Sat Sun'.split()
            w = weekdays[date.weekday()]
            date = ' '.join([w, date.strftime('%d %b %Y')])

            duration = ent['stop_time'] - ent['start_time']
            secs = duration.total_seconds()
            hours, remainder = divmod(secs, 3600)
            minutes, seconds = divmod(remainder, 60)
            duration = '%02d:%02d:%04.1f' % (hours, minutes, seconds)

            parameters = self._dict_to_indented_list(ent['config'])

            result = self._indent(ent['result'].__repr__(), prefix='    ')

            deps = ent['experiment']['dependencies']
            deps = self._indent('\n'.join(deps), prefix='    ')

            resources = [x[0] for x in ent['resources']]
            resources = self._indent('\n'.join(resources), prefix='    ')

            sources = [x[0] for x in ent['experiment']['sources']]
            sources = self._indent('\n'.join(sources), prefix='    ')

            artifacts = [x[0] for x in ent['artifacts']]
            artifacts = self._indent('\n'.join(artifacts), prefix='    ')

            none_str = '    None'

            rec = dict(exp_name=ent['experiment']['name'],
                       exp_id=ent['_id'],
                       start_date=date,
                       duration=duration,
                       parameters=parameters if parameters else none_str,
                       result=result if result else none_str,
                       dependencies=deps if deps else none_str,
                       resources=resources if resources else none_str,
                       sources=sources if sources else none_str,
                       artifacts=artifacts if artifacts else none_str)

            report = template.format(**rec)

            all_matched_entries.append(report)

        return all_matched_entries

    def fetch_metadata(self, exp_name=None, query=None, indices=None):
        """Return all metadata for matching experiment name, index or query."""
        if exp_name or query:
            if query:
                assert isinstance(query, QueryImpl)
                q = query
            elif exp_name:
                q = Query().experiment.name.search(exp_name)

            entries = self.runs.search(q)

        elif indices or indices == 0:
            if not isinstance(indices, (tuple, list)):
                indices = [indices]

            num_recs = len(self.runs)

            for idx in indices:
                if idx >= num_recs:
                    raise ValueError(
                        'Index value ({}) must be less than '
                        'number of records ({})'.format(idx, num_recs))

            entries = [self.runs.all()[ind] for ind in indices]

        else:
            raise ValueError('Must specify an experiment name, indices or '
                             'pass a custom query')

        return entries

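    # A minimal sketch of the three ways to select runs (the names and the
    # config key below are hypothetical; queries use TinyDB's Query objects):
    #
    #     reader.fetch_metadata(exp_name='my_experiment')          # regex search on name
    #     reader.fetch_metadata(indices=[0, 1])                    # positional lookup
    #     reader.fetch_metadata(query=Query().config.seed == 42)   # raw TinyDB query
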
    def _dict_to_indented_list(self, d):

        d = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        output_str = ''

        for k, v in d.items():
            output_str += '%s: %s' % (k, v)
            output_str += '\n'

        output_str = self._indent(output_str.strip(), prefix='    ')

        return output_str

    def _indent(self, message, prefix):
        """Wrapper for indenting strings in Python 2 and 3."""
        preferred_width = 150
        wrapper = textwrap.TextWrapper(initial_indent=prefix,
                                       width=preferred_width,
                                       subsequent_indent=prefix)

        lines = message.splitlines()
        formatted_lines = [wrapper.fill(lin) for lin in lines]
        formatted_text = '\n'.join(formatted_lines)

        return formatted_text


if opt.has_tinydb:  # noqa
    from tinydb import TinyDB, Query
    from tinydb.queries import QueryImpl
    from hashfs import HashFS
    from tinydb_serialization import Serializer, SerializationMiddleware

    class DateTimeSerializer(Serializer):
        OBJ_CLASS = dt.datetime  # The class this serializer handles

        def encode(self, obj):
            return obj.strftime('%Y-%m-%dT%H:%M:%S.%f')

        def decode(self, s):
            return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S.%f')

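    # A round-trip sketch for the serializer above (illustrative only):
    #
    #     s = DateTimeSerializer()
    #     encoded = s.encode(dt.datetime(2017, 1, 1, 12, 30))
    #     assert s.decode(encoded) == dt.datetime(2017, 1, 1, 12, 30)
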
    class NdArraySerializer(Serializer):
        OBJ_CLASS = ndarray_type

        def encode(self, obj):
            return json.dumps(obj.tolist(), check_circular=True)

        def decode(self, s):
            return opt.np.array(json.loads(s))

    class DataFrameSerializer(Serializer):
        OBJ_CLASS = dataframe_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s)

    class SeriesSerializer(Serializer):
        OBJ_CLASS = series_type

        def encode(self, obj):
            return obj.to_json()

        def decode(self, s):
            return opt.pandas.read_json(s, typ='series')

    class FileSerializer(Serializer):
        OBJ_CLASS = BufferedReaderWrapper

        def __init__(self, fs):
            self.fs = fs

        def encode(self, obj):
            address = self.fs.put(obj)
            return json.dumps(address.id)

        def decode(self, s):
            id_ = json.loads(s)
            file_reader = self.fs.open(id_)
            file_reader = BufferedReaderWrapper(file_reader)
            file_reader.hash = id_
            return file_reader