Completed
Push — master ( 6c3661...416f46 )
by Klaus
34s
created

tests/test_observers/test_sql_observer.py (5 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import datetime
6
import hashlib
7
import os
8
9
import pytest
10
import tempfile
11
from sacred.serializer import json
12
13
sqlalchemy = pytest.importorskip("sqlalchemy")
14
15
from sacred.observers.sql import (SqlObserver, Host, Experiment, Run, Source,
16
                                  Resource)
17
18
T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0)
19
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5)
20
21
22
@pytest.fixture
def engine(request):
    """Yield a SQLAlchemy engine for the URL given on the pytest command line.

    The URL is read from the ``--sqlalchemy-connect-url`` option. Once the
    requesting test has finished, the engine is disposed.
    """
    from sqlalchemy.engine import create_engine

    connect_url = request.config.getoption("--sqlalchemy-connect-url")
    db_engine = create_engine(connect_url)
    yield db_engine
    # Teardown: runs after the test that requested this fixture.
    db_engine.dispose()
30
31
32
@pytest.fixture
def session(engine):
    """Yield a scoped session created from ``engine``.

    A connection with an open transaction is held while the test runs. On
    teardown the session is closed, the transaction is rolled back and the
    connection is closed.
    """
    from sqlalchemy.orm import scoped_session, sessionmaker

    conn = engine.connect()
    outer_transaction = conn.begin()
    # NOTE(review): the factory is bound to ``engine`` rather than ``conn``,
    # so it is unclear whether the rollback below covers work done through
    # the session -- confirm this is intended.
    # A scoped session is thread-local, which avoids problems with sqlite
    # (see #275).
    scoped = scoped_session(sessionmaker(bind=engine))
    yield scoped
    scoped.close()
    outer_transaction.rollback()
    conn.close()
44
45
46
@pytest.fixture
def sql_obs(session, engine):
    """Return a SqlObserver built from the ``engine`` and ``session`` fixtures."""
    observer = SqlObserver(engine, session)
    return observer
49
50
51
@pytest.fixture
def sample_run():
    """Return a dict of sample run data.

    The tests in this module unpack it as keyword arguments to
    ``started_event``. A fresh dict is built for every test, so tests may
    mutate it freely.
    """
    return {
        '_id': 'FEDCBA9876543210',
        'ex_info': {
            'name': 'test_exp',
            'sources': [],
            'dependencies': [],
            'base_dir': '/tmp',
        },
        'command': 'run',
        'host_info': {
            'hostname': 'test_host',
            'cpu': 'Intel',
            'os': ['Linux', 'Ubuntu'],
            'python_version': '3.4',
        },
        'start_time': T1,
        'config': {'config': 'True', 'foo': 'bar', 'answer': 42},
        'meta_info': {'comment': 'test run'},
    }
69 View Code Duplication
0 ignored issues
show
This code seems to be duplicated in your project.
Loading history...
70
71
@pytest.fixture
def tmpfile():
    """Yield a closed temporary ``.py`` file containing ``import sacred``.

    The yielded object carries two extra attributes: ``content`` (the text
    written to the file) and ``md5sum`` (hex MD5 digest of the written
    bytes). The file is removed from disk after the test.
    """
    # NOTE: ``delete=False`` plus an explicit ``os.remove`` is used instead of
    # a ``with`` block so that the file can be closed before the test runs.
    # On Windows the same file cannot be opened twice, so an observer could
    # not read it while this handle was still open.
    handle = tempfile.NamedTemporaryFile(suffix='.py', delete=False)

    handle.content = 'import sacred\n'
    payload = handle.content.encode()
    handle.write(payload)
    handle.flush()
    # The file holds exactly ``payload``, so hashing the payload gives the
    # same digest as reading the file back.
    handle.md5sum = hashlib.md5(payload).hexdigest()

    handle.close()

    yield handle

    os.remove(handle.name)
90
91
92
def test_sql_observer_started_event_creates_run(sql_obs, sample_run, session):
    """started_event without a preset id stores one Run, Host and Experiment."""
    sample_run['_id'] = None
    new_id = sql_obs.started_event(**sample_run)
    assert new_id is not None

    for model in (Run, Host, Experiment):
        assert session.query(model).count() == 1

    expected = {
        '_id': new_id,
        'command': sample_run['command'],
        'start_time': sample_run['start_time'],
        'heartbeat': None,
        'stop_time': None,
        'queue_time': None,
        'status': 'RUNNING',
        'result': None,
        'meta': {
            'comment': sample_run['meta_info']['comment'],
            'priority': 0.0,
        },
        'resources': [],
        'artifacts': [],
        'host': sample_run['host_info'],
        'experiment': sample_run['ex_info'],
        'config': sample_run['config'],
        'captured_out': None,
        'fail_trace': None,
    }
    stored_run = session.query(Run).first()
    assert stored_run.to_json() == expected
120
121
122
def test_sql_observer_started_event_uses_given_id(sql_obs, sample_run, session):
    """started_event returns and stores the ``_id`` it was given."""
    given_id = sample_run['_id']

    returned_id = sql_obs.started_event(**sample_run)
    assert returned_id == given_id

    runs = session.query(Run)
    assert runs.count() == 1
    stored_run = runs.first()
    assert stored_run.run_id == given_id
128
129
130
def test_fs_observer_started_event_saves_source(sql_obs, sample_run, session,
                                                tmpfile):
    """started_event stores the experiment's source file as a Source row."""
    source_entry = [tmpfile.name, tmpfile.md5sum]
    sample_run['ex_info']['sources'] = [source_entry]

    sql_obs.started_event(**sample_run)

    assert session.query(Run).count() == 1
    stored_run = session.query(Run).first()
    assert session.query(Source).count() == 1

    stored_sources = stored_run.experiment.sources
    assert len(stored_sources) == 1
    stored = stored_sources[0]
    assert stored.filename == tmpfile.name
    assert stored.content == 'import sacred\n'
    assert stored.md5sum == tmpfile.md5sum
144
145
146
def test_sql_observer_heartbeat_event_updates_run(sql_obs, sample_run, session):
    """heartbeat_event stores beat time, result, info and captured output."""
    sql_obs.started_event(**sample_run)

    beat_info = {'my_info': [1, 2, 3], 'nr': 7}
    output = 'some output'
    sql_obs.heartbeat_event(
        info=beat_info,
        captured_out=output,
        beat_time=T2,
        result=23.5,
    )

    assert session.query(Run).count() == 1
    stored_run = session.query(Run).first()
    assert stored_run.heartbeat == T2
    assert stored_run.result == 23.5
    # ``info`` is decoded with the serializer before it is compared.
    assert json.decode(stored_run.info) == beat_info
    assert stored_run.captured_out == output
160
161
162
def test_sql_observer_completed_event_updates_run(sql_obs, sample_run, session):
    """completed_event records stop time, result and the COMPLETED status."""
    sql_obs.started_event(**sample_run)
    sql_obs.completed_event(stop_time=T2, result=42)

    runs = session.query(Run)
    assert runs.count() == 1

    stored_run = runs.first()
    assert stored_run.stop_time == T2
    assert stored_run.result == 42
    assert stored_run.status == 'COMPLETED'
172
173
174
def test_sql_observer_interrupted_event_updates_run(sql_obs, sample_run, session):
    """interrupted_event records the stop time and the given status."""
    sql_obs.started_event(**sample_run)
    sql_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED')

    runs = session.query(Run)
    assert runs.count() == 1

    stored_run = runs.first()
    assert stored_run.stop_time == T2
    assert stored_run.status == 'INTERRUPTED'
183
184
185
def test_sql_observer_failed_event_updates_run(sql_obs, sample_run, session):
    """failed_event records stop time, FAILED status and the joined trace."""
    sql_obs.started_event(**sample_run)
    trace_lines = ["lots of errors and", "so", "on..."]
    sql_obs.failed_event(fail_time=T2, fail_trace=trace_lines)

    runs = session.query(Run)
    assert runs.count() == 1

    stored_run = runs.first()
    assert stored_run.stop_time == T2
    assert stored_run.status == 'FAILED'
    # The list of trace lines is expected back as one newline-joined string.
    assert stored_run.fail_trace == "lots of errors and\nso\non..."
196
197
198
def test_sql_observer_artifact_event(sql_obs, sample_run, session, tmpfile):
    """artifact_event attaches the file to the run under the given name."""
    sql_obs.started_event(**sample_run)

    sql_obs.artifact_event('my_artifact.py', tmpfile.name)

    runs = session.query(Run)
    assert runs.count() == 1

    stored_artifacts = runs.first().artifacts
    assert len(stored_artifacts) == 1

    stored = stored_artifacts[0]
    assert stored.filename == 'my_artifact.py'
    assert stored.content.decode() == tmpfile.content
211
212
213
def test_fs_observer_resource_event(sql_obs, sample_run, session, tmpfile):
    """resource_event attaches the file, its md5sum and content to the run."""
    sql_obs.started_event(**sample_run)

    sql_obs.resource_event(tmpfile.name)

    runs = session.query(Run)
    assert runs.count() == 1

    stored_resources = runs.first().resources
    assert len(stored_resources) == 1

    stored = stored_resources[0]
    assert stored.filename == tmpfile.name
    assert stored.md5sum == tmpfile.md5sum
    assert stored.content.decode() == tmpfile.content
226
227
228
def test_fs_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile):
    """Two runs sharing one source file produce a single Source row."""
    second_obs = SqlObserver(sql_obs.engine, session)
    sample_run['_id'] = None
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

    # Start the same run once per observer, in a fixed order.
    for observer in (sql_obs, second_obs):
        observer.started_event(**sample_run)

    assert session.query(Run).count() == 2
    assert session.query(Source).count() == 1
238
239
240
def test_fs_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile):
    """Two runs adding the same resource file produce a single Resource row."""
    second_obs = SqlObserver(sql_obs.engine, session)
    sample_run['_id'] = None
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]
    observers = (sql_obs, second_obs)

    # Both runs are started before either of them registers the resource.
    for observer in observers:
        observer.started_event(**sample_run)

    for observer in observers:
        observer.resource_event(tmpfile.name)

    assert session.query(Run).count() == 2
    assert session.query(Resource).count() == 1
253
254
255
def test_sql_observer_equality(sql_obs, engine, session):
    """Observers sharing engine and session compare equal; a string does not.

    Both ``==`` and ``!=`` are used on purpose, in plain and negated form, so
    that each comparison operator is checked on its own.
    """
    twin = SqlObserver(engine, session)

    # Same engine and session: equal, and not unequal.
    assert sql_obs == twin
    assert not sql_obs != twin

    # An unrelated type: not equal, and unequal.
    assert not sql_obs == 'foo'
    assert sql_obs != 'foo'
263