Completed
Push — master ( 50b00b...d515a4 )
by Klaus
11s
created

sacred/observers/sql.py (4 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
import hashlib
5
import json
6
import os
7
from threading import Lock
8
9
import sacred.optional as opt
10
from sacred.commandline_options import CommandLineOption
11
from sacred.dependencies import get_digest
12
from sacred.observers.base import RunObserver
13
from sacred.serializer import flatten, restore
14
15
# Priority given to a SqlObserver when ``create``/``__init__`` get no explicit
# ``priority`` argument.
DEFAULT_SQL_PRIORITY = 40
16
17
18
# ############################# Observer #################################### #
19
20
class SqlObserver(RunObserver):
    """Run observer that records runs in a SQL database using SQLAlchemy.

    Each event updates the ORM ``Run`` instance kept in ``self.run`` and
    then commits the session through :meth:`save`.
    """

    @classmethod
    def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
        """Build an observer for the database at ``url``.

        :param url: SQLAlchemy database URL, typically
            ``dialect://username:password@host:port/database``.
        :param echo: forwarded to ``sqlalchemy.create_engine`` (SQL logging).
        :param priority: priority of the observer.
        :return: a new :class:`SqlObserver`.
        """
        engine = sa.create_engine(url, echo=echo)
        session_factory = sessionmaker(bind=engine)
        # make session thread-local to avoid problems with sqlite (see #275)
        session = scoped_session(session_factory)
        return cls(engine, session, priority)

    def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY):
        self.engine = engine
        self.session = session
        self.priority = priority
        # ORM ``Run`` row of the current run; set by started/queued events.
        self.run = None
        # Serializes the commits issued by ``save``.
        self.lock = Lock()

    def _add_run(self, ex_info, command, host_info, config, meta_info, _id,
                 **run_kwargs):
        """Create, persist and remember a new ``Run`` row.

        Shared implementation of :meth:`started_event` and
        :meth:`queued_event`; event-specific columns (times, status) are
        passed through ``run_kwargs``.

        :return: ``_id`` if it is truthy, otherwise the stored string run id.
        """
        Base.metadata.create_all(self.engine)
        sql_exp = Experiment.get_or_create(ex_info, self.session)
        sql_host = Host.get_or_create(host_info, self.session)
        if _id is None:
            # NOTE(review): the new id is derived from the highest
            # autoincrement primary key, which is not safe against concurrent
            # writers and may collide with explicitly chosen run ids.
            last_run = self.session.query(Run).order_by(Run.id.desc()).first()
            _id = 0 if last_run is None else last_run.id + 1

        self.run = Run(run_id=str(_id),
                       config=json.dumps(flatten(config)),
                       command=command,
                       priority=meta_info.get('priority', 0),
                       comment=meta_info.get('comment', ''),
                       experiment=sql_exp,
                       host=sql_host,
                       **run_kwargs)
        self.session.add(self.run)
        self.save()
        return _id or self.run.run_id

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Record a run that has started; returns its id."""
        return self._add_run(ex_info, command, host_info, config, meta_info,
                             _id, start_time=start_time, status='RUNNING')

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        """Record a queued run, including its ``queue_time``; returns its id."""
        return self._add_run(ex_info, command, host_info, config, meta_info,
                             _id, queue_time=queue_time, status='QUEUED')

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Store the latest info, captured output, heartbeat time and result."""
        self.run.info = json.dumps(flatten(info))
        self.run.captured_out = captured_out
        self.run.heartbeat = beat_time
        self.run.result = result
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run as COMPLETED with its stop time and result."""
        self.run.stop_time = stop_time
        self.run.result = result
        self.run.status = 'COMPLETED'
        self.save()

    def interrupted_event(self, interrupt_time, status):
        """Mark the run as stopped with the given ``status``."""
        self.run.stop_time = interrupt_time
        self.run.status = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        """Mark the run as FAILED; ``fail_trace`` lines are joined by '\\n'."""
        self.run.stop_time = fail_time
        self.run.fail_trace = '\n'.join(fail_trace)
        self.run.status = 'FAILED'
        self.save()

    def resource_event(self, filename):
        """Attach the resource file ``filename`` to the run."""
        res = Resource.get_or_create(filename, self.session)
        self.run.resources.append(res)
        self.save()

    def artifact_event(self, name, filename):
        """Attach the contents of ``filename`` as artifact ``name``."""
        a = Artifact.create(name, filename)
        self.run.artifacts.append(a)
        self.save()

    def save(self):
        """Commit the session; roll it back if the commit fails.

        A failed flush/commit leaves a SQLAlchemy session unusable until it is
        rolled back, so without the rollback every later event would fail as
        well. The original exception is re-raised.
        """
        with self.lock:
            try:
                self.session.commit()
            except Exception:
                self.session.rollback()
                raise

    def query(self, _id):
        """Return the JSON-like dict of the run whose primary key is ``_id``.

        NOTE(review): looks up the integer ``Run.id`` column, not ``run_id``;
        an unknown id raises ``AttributeError`` because ``first()`` returns
        ``None``.
        """
        run = self.session.query(Run).filter_by(id=_id).first()
        return run.to_json()

    def __eq__(self, other):
        if isinstance(other, SqlObserver):
            # fixme: this will probably fail to detect two equivalent engines
            return (self.engine == other.engine and
                    self.session == other.session)
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
131
132
133
# ######################## Commandline Option ############################### #
134
135
class SqlOption(CommandLineOption):
    """Add a SQL Observer to the experiment."""

    __depends_on__ = 'sqlalchemy'

    arg = 'DB_URL'
    arg_description = ("The typical form is: "
                       "dialect://username:password@host:port/database")

    @classmethod
    def apply(cls, args, run):
        # ``args`` is handed to SqlObserver.create as the database URL.
        observer = SqlObserver.create(args)
        run.observers.append(observer)
146
147
148
# ################################ ORM ###################################### #
149
if opt.has_sqlalchemy:  # noqa
    import sqlalchemy as sa
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker, scoped_session

    # Declarative base of every ORM class below; SqlObserver creates the
    # tables through ``Base.metadata.create_all``.
    Base = declarative_base()

    class Source(Base):
        """A source file of an experiment together with its text content."""

        __tablename__ = 'source'

        @classmethod
        def get_or_create(cls, filename, md5sum, session, basedir):
            """Return the stored source matching ``filename`` and ``md5sum``.

            If no row matches, ``basedir/filename`` is read from disk, its
            digest is checked against ``md5sum`` and a new instance (not added
            to the session here) is returned.

            :raises ValueError: if the file on disk does not match ``md5sum``.
            """
            instance = session.query(cls).filter_by(filename=filename,
                                                    md5sum=md5sum).first()
            if instance:
                return instance
            full_path = os.path.join(basedir, filename)
            md5sum_ = get_digest(full_path)
            # Explicit check instead of ``assert`` so it also runs under -O.
            if md5sum_ != md5sum:
                raise ValueError('found md5 mismatch for {}: {} != {}'
                                 .format(filename, md5sum, md5sum_))
            with open(full_path, 'r') as f:
                return cls(filename=filename, md5sum=md5sum, content=f.read())

        source_id = sa.Column(sa.Integer, primary_key=True)
        filename = sa.Column(sa.String(256))
        md5sum = sa.Column(sa.String(32))
        content = sa.Column(sa.Text)

        def to_json(self):
            return {'filename': self.filename,
                    'md5sum': self.md5sum}

    class Dependency(Base):
        """A package dependency of an experiment."""

        __tablename__ = 'dependency'

        @classmethod
        def get_or_create(cls, dep, session):
            """Return the stored dependency described by ``dep`` or a new one.

            ``dep`` is split at the first ``'=='`` into name and version; when
            it contains no ``'=='`` the version is the empty string.
            """
            name, _, version = dep.partition('==')
            instance = session.query(cls).filter_by(name=name,
                                                    version=version).first()
            if instance:
                return instance
            return cls(name=name, version=version)

        dependency_id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(32))
        version = sa.Column(sa.String(16))

        def to_json(self):
            return "{}=={}".format(self.name, self.version)

    class Artifact(Base):
        """Binary content attached to a run under a given name."""

        __tablename__ = 'artifact'

        @classmethod
        def create(cls, name, filename):
            """Read ``filename`` in binary mode and store it as ``name``.

            Note that the ``filename`` column receives ``name``; the path
            argument is only used to read the content.
            """
            with open(filename, 'rb') as f:
                return cls(filename=name, content=f.read())

        artifact_id = sa.Column(sa.Integer, primary_key=True)
        filename = sa.Column(sa.String(64))
        content = sa.Column(sa.LargeBinary)

        run_id = sa.Column(sa.String(24), sa.ForeignKey('run.run_id'))
        run = sa.orm.relationship("Run", backref=sa.orm.backref('artifacts'))

        def to_json(self):
            return {'_id': self.artifact_id,
                    'filename': self.filename}

    class Resource(Base):
        """A resource file used by runs, stored with its binary content."""

        __tablename__ = 'resource'

        @classmethod
        def get_or_create(cls, filename, session):
            """Return the stored resource matching the file's path and md5sum.

            If no row matches, a new instance holding the file's bytes is
            returned (not added to the session here).
            """
            md5sum = get_digest(filename)
            instance = session.query(cls).filter_by(filename=filename,
                                                    md5sum=md5sum).first()
            if instance:
                return instance
            with open(filename, 'rb') as f:
                return cls(filename=filename, md5sum=md5sum, content=f.read())

        resource_id = sa.Column(sa.Integer, primary_key=True)
        filename = sa.Column(sa.String(256))
        md5sum = sa.Column(sa.String(32))
        content = sa.Column(sa.LargeBinary)

        def to_json(self):
            return {'filename': self.filename,
                    'md5sum': self.md5sum}

    class Host(Base):
        """The machine a run was executed on."""

        __tablename__ = 'host'

        @classmethod
        def get_or_create(cls, host_info, session):
            """Return the host row equal to ``host_info`` or a new instance.

            ``host_info['os']`` is indexed as a pair: (os, os_info).
            """
            h = dict(
                hostname=host_info['hostname'],
                cpu=host_info['cpu'],
                os=host_info['os'][0],
                os_info=host_info['os'][1],
                python_version=host_info['python_version']
            )

            return session.query(cls).filter_by(**h).first() or cls(**h)

        host_id = sa.Column(sa.Integer, primary_key=True)
        cpu = sa.Column(sa.String(64))
        hostname = sa.Column(sa.String(64))
        os = sa.Column(sa.String(16))
        os_info = sa.Column(sa.String(64))
        python_version = sa.Column(sa.String(16))

        def to_json(self):
            return {'cpu': self.cpu,
                    'hostname': self.hostname,
                    'os': [self.os, self.os_info],
                    'python_version': self.python_version}

    # Many-to-many link between experiments and their source files.
    experiment_source_association = sa.Table(
        'experiments_sources', Base.metadata,
        sa.Column('experiment_id', sa.Integer,
                  sa.ForeignKey('experiment.experiment_id')),
        sa.Column('source_id', sa.Integer, sa.ForeignKey('source.source_id'))
    )

    # Many-to-many link between experiments and their dependencies.
    experiment_dependency_association = sa.Table(
        'experiments_dependencies', Base.metadata,
        sa.Column('experiment_id', sa.Integer,
                  sa.ForeignKey('experiment.experiment_id')),
        sa.Column('dependency_id', sa.Integer,
                  sa.ForeignKey('dependency.dependency_id'))
    )

    class Experiment(Base):
        """An experiment, identified by its name and an md5 of its info."""

        __tablename__ = 'experiment'

        @classmethod
        def get_or_create(cls, ex_info, session):
            """Return the experiment matching ``ex_info`` or a new instance.

            Uniqueness is decided by the name plus the md5 of the JSON dump of
            ``ex_info``; a new instance also collects its dependencies and
            sources.
            """
            name = ex_info['name']
            # Compute a MD5sum of the ex_info to determine its uniqueness
            # NOTE(review): json.dumps is called without sort_keys, so the
            # hash depends on the key order of ex_info.
            h = hashlib.md5()
            h.update(json.dumps(ex_info).encode())
            md5 = h.hexdigest()
            instance = session.query(cls).filter_by(name=name,
                                                    md5sum=md5).first()
            if instance:
                return instance

            dependencies = [Dependency.get_or_create(d, session)
                            for d in ex_info['dependencies']]
            sources = [Source.get_or_create(s, md5sum, session,
                                            ex_info['base_dir'])
                       for s, md5sum in ex_info['sources']]

            return cls(name=name, dependencies=dependencies, sources=sources,
                       md5sum=md5, base_dir=ex_info['base_dir'])

        experiment_id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(32))
        md5sum = sa.Column(sa.String(32))
        base_dir = sa.Column(sa.String(64))
        sources = sa.orm.relationship("Source",
                                      secondary=experiment_source_association,
                                      backref="experiments")
        dependencies = sa.orm.relationship(
            "Dependency",
            secondary=experiment_dependency_association,
            backref="experiments")

        def to_json(self):
            return {'name': self.name,
                    'base_dir': self.base_dir,
                    'sources': [s.to_json() for s in self.sources],
                    'dependencies': [d.to_json() for d in self.dependencies]}

    # Many-to-many link between runs and the resources they used.
    run_resource_association = sa.Table(
        'runs_resources', Base.metadata,
        sa.Column('run_id', sa.String(24), sa.ForeignKey('run.run_id')),
        sa.Column('resource_id', sa.Integer,
                  sa.ForeignKey('resource.resource_id'))
    )

    class Run(Base):
        """A single run of an experiment and everything recorded about it."""

        __tablename__ = 'run'
        id = sa.Column(sa.Integer, primary_key=True)

        run_id = sa.Column(sa.String(24), unique=True)

        command = sa.Column(sa.String(64))

        # times
        start_time = sa.Column(sa.DateTime)
        heartbeat = sa.Column(sa.DateTime)
        stop_time = sa.Column(sa.DateTime)
        queue_time = sa.Column(sa.DateTime)

        # meta info
        priority = sa.Column(sa.Float)
        comment = sa.Column(sa.Text)

        fail_trace = sa.Column(sa.Text)

        # Captured out
        # TODO: move to separate table?
        captured_out = sa.Column(sa.Text)

        # Configuration & info
        # TODO: switch type to json if possible
        config = sa.Column(sa.Text)
        info = sa.Column(sa.Text)

        # "QUEUED" is the status SqlObserver.queued_event writes; without it
        # backends that enforce the enum reject queued runs.
        status = sa.Column(sa.Enum("RUNNING", "COMPLETED", "INTERRUPTED",
                                   "TIMEOUT", "FAILED", "QUEUED",
                                   name="status_enum"))

        host_id = sa.Column(sa.Integer, sa.ForeignKey('host.host_id'))
        host = sa.orm.relationship("Host", backref=sa.orm.backref('runs'))

        experiment_id = sa.Column(sa.Integer,
                                  sa.ForeignKey('experiment.experiment_id'))
        experiment = sa.orm.relationship("Experiment",
                                         backref=sa.orm.backref('runs'))

        # artifacts = backref
        resources = sa.orm.relationship("Resource",
                                        secondary=run_resource_association,
                                        backref="runs")

        result = sa.Column(sa.Float)

        def to_json(self):
            """Return the run as a dict; ``config`` is decoded from JSON."""
            return {
                '_id': self.run_id,
                'command': self.command,
                'start_time': self.start_time,
                'heartbeat': self.heartbeat,
                'stop_time': self.stop_time,
                'queue_time': self.queue_time,
                'status': self.status,
                'result': self.result,
                'meta': {
                    'comment': self.comment,
                    'priority': self.priority},
                'resources': [r.to_json() for r in self.resources],
                'artifacts': [a.to_json() for a in self.artifacts],
                'host': self.host.to_json(),
                'experiment': self.experiment.to_json(),
                'config': restore(json.loads(self.config)),
                'captured_out': self.captured_out,
                'fail_trace': self.fail_trace,
            }
401