Completed
Push — master ( dbc38f...56accc )
by Klaus
01:34
created

Run.to_json()   A

Complexity

Conditions 3

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
dl 0
loc 20
rs 9.4285
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
import hashlib
5
import json
6
7
import sqlalchemy as sa
8
from sqlalchemy.ext.declarative import declarative_base
9
from sqlalchemy.orm import sessionmaker
10
11
from sacred.commandline_options import CommandLineOption
12
from sacred.dependencies import get_digest
13
from sacred.observers.base import RunObserver
14
from sacred.serializer import flatten, restore
15
16
# ################################ ORM ###################################### #
17
# Declarative base shared by every ORM model in this module; its ``metadata``
# is what SqlObserver passes to ``create_all`` to create the tables.
Base = declarative_base()
18
19
20 View Code Duplication
class Source(Base):
    """A source file of an experiment, stored once per (filename, md5sum)."""

    __tablename__ = 'source'

    @classmethod
    def get_or_create(cls, filename, md5sum, session):
        """Return the stored Source for ``filename``/``md5sum`` or build one.

        Args:
            filename: Path of the source file.
            md5sum: Digest recorded for the file by the experiment.
            session: SQLAlchemy session used for the lookup.

        Returns:
            An existing ``Source`` row, or a new instance (not added to the
            session here) holding the file's current text content.

        Raises:
            ValueError: If the file on disk no longer matches ``md5sum``.
        """
        instance = session.query(cls).filter_by(filename=filename,
                                                md5sum=md5sum).first()
        if instance:
            return instance
        # The file may have changed since its digest was recorded; storing the
        # new content under the old digest would corrupt the record. This is
        # an explicit check rather than ``assert`` so it survives ``python -O``.
        current_md5sum = get_digest(filename)
        if current_md5sum != md5sum:
            raise ValueError(
                'Weird: found md5 mismatch for {}: {} != {}'.format(
                    filename, md5sum, current_md5sum))
        with open(filename, 'r') as f:
            return cls(filename=filename, md5sum=md5sum, content=f.read())

    source_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.Text)  # full text of the source file

    def to_json(self):
        """Return the identifying fields (not the content) as a dict."""
        return {'filename': self.filename,
                'md5sum': self.md5sum}
43
44
45
class Dependency(Base):
    """A package requirement of an experiment, written as ``name==version``."""

    __tablename__ = 'dependency'

    dependency_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    version = sa.Column(sa.String(16))

    @classmethod
    def get_or_create(cls, dep, session):
        """Look up ``dep`` (``name==version``) or build a new, unsaved row.

        A string without ``==`` yields an empty version.
        """
        pkg_name, _sep, pkg_version = dep.partition('==')
        existing = (session.query(cls)
                    .filter_by(name=pkg_name, version=pkg_version)
                    .first())
        return existing if existing else cls(name=pkg_name,
                                             version=pkg_version)

    def to_json(self):
        """Render the dependency back into ``name==version`` form."""
        return '{name}=={version}'.format(name=self.name,
                                          version=self.version)
63
64
65
class Artifact(Base):
    """A file produced by a run, stored as a binary blob."""

    __tablename__ = 'artifact'

    artifact_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(64))  # the artifact's name, not its path
    content = sa.Column(sa.LargeBinary)

    run_id = sa.Column(sa.String(24), sa.ForeignKey('run.run_id'))
    run = sa.orm.relationship("Run", backref=sa.orm.backref('artifacts'))

    @classmethod
    def create(cls, name, filename):
        """Build an Artifact named ``name`` from the raw bytes of ``filename``."""
        with open(filename, 'rb') as handle:
            payload = handle.read()
        return cls(filename=name, content=payload)

    def to_json(self):
        """Return the artifact's id and name (not its content)."""
        return {'_id': self.artifact_id,
                'filename': self.filename}
83
84
85 View Code Duplication
class Resource(Base):
    """A file read by a run, stored once per (filename, md5sum)."""

    __tablename__ = 'resource'

    resource_id = sa.Column(sa.Integer, primary_key=True)
    filename = sa.Column(sa.String(256))
    md5sum = sa.Column(sa.String(32))
    content = sa.Column(sa.LargeBinary)

    @classmethod
    def get_or_create(cls, filename, session):
        """Return the stored Resource matching the file's current digest.

        If none exists, a new instance (not added to the session here) is
        built from the file's raw bytes.
        """
        digest = get_digest(filename)
        match = session.query(cls).filter_by(filename=filename,
                                             md5sum=digest).first()
        if not match:
            with open(filename, 'rb') as handle:
                match = cls(filename=filename, md5sum=digest,
                            content=handle.read())
        return match

    def to_json(self):
        """Return the identifying fields (not the content) as a dict."""
        return {'filename': self.filename,
                'md5sum': self.md5sum}
106
107
108
class Host(Base):
    """The machine a run executed on."""

    __tablename__ = 'host'

    @classmethod
    def get_or_create(cls, host_info, session):
        """Find a Host whose fields all match ``host_info``, or build one.

        ``host_info['os']`` is indexed as ``[os, os_info]``.
        """
        attrs = dict(
            hostname=host_info['hostname'],
            cpu=host_info['cpu'],
            os=host_info['os'][0],
            os_info=host_info['os'][1],
            python_version=host_info['python_version'],
        )
        found = session.query(cls).filter_by(**attrs).first()
        if found:
            return found
        return cls(**attrs)

    host_id = sa.Column(sa.Integer, primary_key=True)
    cpu = sa.Column(sa.String(64))
    hostname = sa.Column(sa.String(64))
    os = sa.Column(sa.String(16))
    os_info = sa.Column(sa.String(64))
    python_version = sa.Column(sa.String(16))

    def to_json(self):
        """Return the host description, re-pairing os and os_info."""
        return {'cpu': self.cpu,
                'hostname': self.hostname,
                'os': [self.os, self.os_info],
                'python_version': self.python_version}
135
136
137
# Many-to-many link between experiments and their source files.
experiment_source_association = sa.Table(
    'experiments_sources',
    Base.metadata,
    sa.Column('experiment_id',
              sa.Integer,
              sa.ForeignKey('experiment.experiment_id')),
    sa.Column('source_id',
              sa.Integer,
              sa.ForeignKey('source.source_id')),
)
143
144
# Many-to-many link between experiments and their package dependencies.
experiment_dependency_association = sa.Table(
    'experiments_dependencies',
    Base.metadata,
    sa.Column('experiment_id',
              sa.Integer,
              sa.ForeignKey('experiment.experiment_id')),
    sa.Column('dependency_id',
              sa.Integer,
              sa.ForeignKey('dependency.dependency_id')),
)
151
152
153
class Experiment(Base):
    """An experiment definition, identified by its name and an md5 digest."""

    __tablename__ = 'experiment'

    @classmethod
    def get_or_create(cls, ex_info, session):
        """Return the stored Experiment matching ``ex_info`` or build one.

        Args:
            ex_info: JSON-serializable experiment info; this method reads the
                keys ``name``, ``base_dir``, ``dependencies`` (strings of the
                form ``name==version``) and ``sources`` (``(filename,
                md5sum)`` pairs).
            session: SQLAlchemy session used for the lookups.

        Returns:
            An existing Experiment with the same name and digest, or a new
            instance (not added to the session here).
        """
        name = ex_info['name']
        # Compute an MD5 sum of ex_info to determine its uniqueness. Keys are
        # sorted so that the digest does not depend on dict iteration order,
        # which is not stable across processes on hash-randomized interpreters
        # without insertion-ordered dicts; otherwise the same experiment could
        # be stored twice. NOTE: rows written with the old, unsorted digest
        # will not match and get re-created once.
        md5 = hashlib.md5(
            json.dumps(ex_info, sort_keys=True).encode()).hexdigest()
        instance = session.query(cls).filter_by(name=name, md5sum=md5).first()
        if instance:
            return instance

        dependencies = [Dependency.get_or_create(d, session)
                        for d in ex_info['dependencies']]
        sources = [Source.get_or_create(s, md5sum, session)
                   for s, md5sum in ex_info['sources']]

        return cls(name=name, dependencies=dependencies, sources=sources,
                   md5sum=md5, base_dir=ex_info['base_dir'])

    experiment_id = sa.Column(sa.Integer, primary_key=True)
    name = sa.Column(sa.String(32))
    md5sum = sa.Column(sa.String(32))  # digest of the full ex_info
    base_dir = sa.Column(sa.String(64))
    sources = sa.orm.relationship("Source",
                                  secondary=experiment_source_association,
                                  backref="experiments")
    dependencies = sa.orm.relationship(
        "Dependency",
        secondary=experiment_dependency_association,
        backref="experiments")

    def to_json(self):
        """Return the experiment with its sources and dependencies."""
        return {'name': self.name,
                'base_dir': self.base_dir,
                'sources': [s.to_json() for s in self.sources],
                'dependencies': [d.to_json() for d in self.dependencies]}
192
193
194
# Many-to-many link between runs and the resource files they read.
run_resource_association = sa.Table(
    'runs_resources',
    Base.metadata,
    sa.Column('run_id',
              sa.String(24),
              sa.ForeignKey('run.run_id')),
    sa.Column('resource_id',
              sa.Integer,
              sa.ForeignKey('resource.resource_id')),
)
199
200
201
class Run(Base):
    """A single run of an experiment, including its configuration and outcome."""

    __tablename__ = 'run'

    run_id = sa.Column(sa.String(24), primary_key=True)

    command = sa.Column(sa.String(64))

    # times
    start_time = sa.Column(sa.DateTime)
    heartbeat = sa.Column(sa.DateTime)
    stop_time = sa.Column(sa.DateTime)
    queue_time = sa.Column(sa.DateTime)

    # meta info
    priority = sa.Column(sa.Float)
    comment = sa.Column(sa.Text)

    fail_trace = sa.Column(sa.Text)

    # Captured out
    # TODO: move to separate table?
    captured_out = sa.Column(sa.Text)

    # Configuration & info, stored as JSON text
    # TODO: switch type to json if possible
    config = sa.Column(sa.Text)
    info = sa.Column(sa.Text)

    # "QUEUED" is the status SqlObserver.queued_event writes; an Enum without
    # it lets the database reject queued runs (CHECK constraint or native
    # ENUM, depending on the backend). PostgreSQL also requires ENUM types to
    # be named. Note that create_all does not alter already-existing tables.
    status = sa.Column(sa.Enum("RUNNING", "COMPLETED", "INTERRUPTED",
                               "TIMEOUT", "FAILED", "QUEUED",
                               name="status_enum"))

    host_id = sa.Column(sa.Integer, sa.ForeignKey('host.host_id'))
    host = sa.orm.relationship("Host", backref=sa.orm.backref('runs'))

    experiment_id = sa.Column(sa.Integer,
                              sa.ForeignKey('experiment.experiment_id'))
    experiment = sa.orm.relationship("Experiment",
                                     backref=sa.orm.backref('runs'))

    # artifacts = backref
    resources = sa.orm.relationship("Resource",
                                    secondary=run_resource_association,
                                    backref="runs")

    result = sa.Column(sa.Float)

    def to_json(self):
        """Return this run as a dict.

        ``host`` and ``experiment`` are ``None`` when the relationship is
        unset; for example, queued runs are stored without a host.
        """
        return {
            '_id': self.run_id,
            'command': self.command,
            'start_time': self.start_time,
            'heartbeat': self.heartbeat,
            'stop_time': self.stop_time,
            'queue_time': self.queue_time,
            'status': self.status,
            'result': self.result,
            'meta': {
                'comment': self.comment,
                'priority': self.priority},
            'resources': [r.to_json() for r in self.resources],
            'artifacts': [a.to_json() for a in self.artifacts],
            'host': self.host.to_json() if self.host is not None else None,
            'experiment': (self.experiment.to_json()
                           if self.experiment is not None else None),
            'config': restore(json.loads(self.config)),
            'captured_out': self.captured_out,
            'fail_trace': self.fail_trace,
        }
268
269
270
# ############################# Observer #################################### #
271
272
class SqlObserver(RunObserver):
    """Run observer that records runs in a SQL database through SQLAlchemy."""

    @classmethod
    def create(cls, url, echo=False):
        """Build an observer for database ``url`` with a fresh session.

        ``echo`` is forwarded to ``sqlalchemy.create_engine``.
        """
        engine = sa.create_engine(url, echo=echo)
        return cls(engine, sessionmaker(bind=engine)())

    def __init__(self, engine, session):
        self.engine = engine
        self.session = session
        self.run = None  # Run row of the run currently being observed

    def _next_run_id(self):
        """Return one more than the largest numeric run id, as a string.

        ``run_id`` is a string column, so ``ORDER BY run_id`` is lexicographic
        ('9' sorts above '10') and would hand out an id that already exists
        once ten or more runs are stored. The maximum is therefore computed
        numerically. Ids that are not integers (e.g. supplied by the caller)
        are skipped. Returns '0' when no numeric id exists.
        """
        highest = -1
        for (run_id,) in self.session.query(Run.run_id):
            try:
                highest = max(highest, int(run_id))
            except (TypeError, ValueError):
                continue  # non-numeric id; not part of the numeric sequence
        return str(highest + 1)

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Store a new RUNNING run and return its id."""
        Base.metadata.create_all(self.engine)
        sql_exp = Experiment.get_or_create(ex_info, self.session)
        sql_host = Host.get_or_create(host_info, self.session)
        if _id is None:
            _id = self._next_run_id()

        self.run = Run(run_id=_id,
                       start_time=start_time,
                       config=json.dumps(flatten(config)),
                       command=command,
                       priority=meta_info.get('priority', 0),
                       comment=meta_info.get('comment', ''),
                       experiment=sql_exp,
                       host=sql_host,
                       status='RUNNING')
        self.session.add(self.run)
        self.session.commit()
        return _id or self.run.run_id

    def queued_event(self, ex_info, command, queue_time, config, meta_info,
                     _id):
        """Store a new QUEUED run (without host) and return its id."""
        Base.metadata.create_all(self.engine)
        sql_exp = Experiment.get_or_create(ex_info, self.session)
        if _id is None:
            # Previously read the nonexistent ``Run.id`` attribute.
            _id = self._next_run_id()

        self.run = Run(run_id=_id,
                       config=json.dumps(flatten(config)),
                       command=command,
                       priority=meta_info.get('priority', 0),
                       comment=meta_info.get('comment', ''),
                       experiment=sql_exp,
                       status='QUEUED')
        self.session.add(self.run)
        self.session.commit()
        return _id or self.run.run_id

    def heartbeat_event(self, info, captured_out, beat_time):
        """Store the latest info (as JSON), captured output and beat time."""
        self.run.info = json.dumps(flatten(info))
        self.run.captured_out = captured_out
        self.run.heartbeat = beat_time
        self.session.commit()

    def completed_event(self, stop_time, result):
        """Mark the current run COMPLETED and store its result."""
        self.run.stop_time = stop_time
        self.run.result = result
        self.run.status = 'COMPLETED'
        self.session.commit()

    def interrupted_event(self, interrupt_time, status):
        """Record the stop time and the given ``status`` for the current run."""
        self.run.stop_time = interrupt_time
        self.run.status = status
        self.session.commit()

    def failed_event(self, fail_time, fail_trace):
        """Mark the current run FAILED, storing the trace lines joined."""
        self.run.stop_time = fail_time
        self.run.fail_trace = '\n'.join(fail_trace)
        self.run.status = 'FAILED'
        self.session.commit()

    def resource_event(self, filename):
        """Attach the resource file ``filename`` to the current run."""
        res = Resource.get_or_create(filename, self.session)
        self.run.resources.append(res)
        self.session.commit()

    def artifact_event(self, name, filename):
        """Store the file ``filename`` as an artifact named ``name``."""
        a = Artifact.create(name, filename)
        self.run.artifacts.append(a)
        self.session.commit()

    def query(self, _id):
        """Return the stored run with id ``_id`` as produced by Run.to_json.

        Raises:
            KeyError: If no run with that id is stored.
        """
        # Run's primary key column is ``run_id``; there is no ``id`` column.
        run = self.session.query(Run).filter_by(run_id=_id).first()
        if run is None:
            raise KeyError('No run with _id {} found'.format(_id))
        return run.to_json()

    def __eq__(self, other):
        if isinstance(other, SqlObserver):
            # fixme: this will probably fail to detect two equivalent engines
            return (self.engine == other.engine and
                    self.session == other.session)
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
371
372
373
# ######################## Commandline Option ############################### #
374
375
class SqlOption(CommandLineOption):
    """Add a SQL Observer to the experiment."""

    # Name and help text of the option's single argument.
    arg = 'DB_URL'
    arg_description = ("The typical form is: "
                       "dialect://username:password@host:port/database")

    @classmethod
    def apply(cls, args, run):
        # ``args`` is the database URL given on the command line.
        observer = SqlObserver.create(args)
        run.observers.append(observer)
385