1 | #!/usr/bin/env python |
||
2 | # coding=utf-8 |
||
3 | from __future__ import division, print_function, unicode_literals |
||
4 | |||
5 | import datetime |
||
6 | import hashlib |
||
7 | import os |
||
8 | |||
9 | import pytest |
||
10 | import tempfile |
||
11 | from sacred.serializer import json |
||
12 | |||
13 | sqlalchemy = pytest.importorskip("sqlalchemy") |
||
14 | |||
15 | from sacred.observers.sql import (SqlObserver, Host, Experiment, Run, Source, |
||
16 | Resource) |
||
17 | |||
# Fixed, deterministic timestamps reused across the tests below as
# start / heartbeat / stop times (T1 precedes T2).
T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0)
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5)
||
20 | |||
21 | |||
@pytest.fixture
def engine(request):
    """Yield a SQLAlchemy engine built from ``--sqlalchemy-connect-url``.

    The engine is disposed on teardown so no connections leak between tests.
    """
    from sqlalchemy.engine import create_engine
    connect_url = request.config.getoption("--sqlalchemy-connect-url")
    db_engine = create_engine(connect_url)
    yield db_engine
    db_engine.dispose()
||
30 | |||
31 | |||
@pytest.fixture
def session(engine):
    """Yield a session bound to *engine* inside a transaction.

    The outer transaction is rolled back on teardown, so each test leaves
    the database untouched.
    """
    from sqlalchemy.orm import scoped_session, sessionmaker
    conn = engine.connect()
    outer_tx = conn.begin()
    # scoped_session makes the session thread-local, which avoids
    # problems with sqlite (see #275).
    db_session = scoped_session(sessionmaker(bind=engine))
    yield db_session
    db_session.close()
    outer_tx.rollback()
    conn.close()
||
44 | |||
45 | |||
@pytest.fixture
def sql_obs(session, engine):
    """Provide a SqlObserver wired to the test engine and session."""
    observer = SqlObserver(engine, session)
    return observer
||
49 | |||
50 | |||
@pytest.fixture
def sample_run():
    """Return the keyword arguments for a ``SqlObserver.started_event`` call."""
    experiment_info = {
        'name': 'test_exp',
        'sources': [],
        'dependencies': [],
        'base_dir': '/tmp',
    }
    host_info = {
        'hostname': 'test_host',
        'cpu': 'Intel',
        'os': ['Linux', 'Ubuntu'],
        'python_version': '3.4',
    }
    return {
        '_id': 'FEDCBA9876543210',
        'ex_info': experiment_info,
        'command': 'run',
        'host_info': host_info,
        'start_time': T1,
        'config': {'config': 'True', 'foo': 'bar', 'answer': 42},
        'meta_info': {'comment': 'test run'},
    }
70 | |||
@pytest.fixture
def tmpfile():
    """Yield a closed temporary ``.py`` file annotated with ``.content``
    and ``.md5sum`` attributes.

    The file is created with ``delete=False`` and removed manually on
    teardown, so it can be closed before the test body runs.  This matters
    on Windows, where the same file cannot be opened twice: observers that
    want to read the file need it to be closed first.
    """
    handle = tempfile.NamedTemporaryFile(suffix='.py', delete=False)

    handle.content = 'import sacred\n'
    handle.write(handle.content.encode())
    handle.flush()
    handle.seek(0)
    handle.md5sum = hashlib.md5(handle.read()).hexdigest()

    handle.close()

    yield handle

    os.remove(handle.name)
||
90 | |||
91 | |||
def test_sql_observer_started_event_creates_run(sql_obs, sample_run, session):
    """started_event without an id creates Run, Host and Experiment rows
    and returns a freshly generated run id."""
    sample_run['_id'] = None
    run_id = sql_obs.started_event(**sample_run)
    assert run_id is not None
    assert session.query(Run).count() == 1
    assert session.query(Host).count() == 1
    assert session.query(Experiment).count() == 1
    stored = session.query(Run).first()
    expected = {
        '_id': run_id,
        'command': sample_run['command'],
        'start_time': sample_run['start_time'],
        'heartbeat': None,
        'stop_time': None,
        'queue_time': None,
        'status': 'RUNNING',
        'result': None,
        'meta': {'comment': sample_run['meta_info']['comment'],
                 'priority': 0.0},
        'resources': [],
        'artifacts': [],
        'host': sample_run['host_info'],
        'experiment': sample_run['ex_info'],
        'config': sample_run['config'],
        'captured_out': None,
        'fail_trace': None,
    }
    assert stored.to_json() == expected
||
120 | |||
121 | |||
def test_sql_observer_started_event_uses_given_id(sql_obs, sample_run, session):
    """A run started with an explicit ``_id`` keeps that id in the database."""
    returned_id = sql_obs.started_event(**sample_run)
    assert returned_id == sample_run['_id']
    assert session.query(Run).count() == 1
    stored = session.query(Run).first()
    assert stored.run_id == sample_run['_id']
||
128 | |||
129 | |||
def test_sql_observer_started_event_saves_source(sql_obs, sample_run, session,
                                                 tmpfile):
    """started_event stores the experiment's source files in the Source table.

    NOTE(review): renamed from ``test_fs_observer_...`` — the test exercises
    the SqlObserver, not the FileStorageObserver; the old prefix was a
    copy-paste leftover and inconsistent with the sibling tests.
    """
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

    sql_obs.started_event(**sample_run)

    assert session.query(Run).count() == 1
    db_run = session.query(Run).first()
    assert session.query(Source).count() == 1
    assert len(db_run.experiment.sources) == 1
    source = db_run.experiment.sources[0]
    assert source.filename == tmpfile.name
    assert source.content == 'import sacred\n'
    assert source.md5sum == tmpfile.md5sum
||
144 | |||
145 | |||
def test_sql_observer_heartbeat_event_updates_run(sql_obs, sample_run, session):
    """heartbeat_event updates beat time, result, info, and captured output."""
    sql_obs.started_event(**sample_run)

    run_info = {'my_info': [1, 2, 3], 'nr': 7}
    captured = 'some output'
    sql_obs.heartbeat_event(info=run_info, captured_out=captured,
                            beat_time=T2, result=23.5)

    assert session.query(Run).count() == 1
    stored = session.query(Run).first()
    assert stored.heartbeat == T2
    assert stored.result == 23.5
    # info is stored serialized; decode before comparing.
    assert json.decode(stored.info) == run_info
    assert stored.captured_out == captured
||
160 | |||
161 | |||
def test_sql_observer_completed_event_updates_run(sql_obs, sample_run, session):
    """completed_event records the stop time, result and COMPLETED status."""
    sql_obs.started_event(**sample_run)
    sql_obs.completed_event(stop_time=T2, result=42)

    assert session.query(Run).count() == 1
    stored = session.query(Run).first()

    assert stored.stop_time == T2
    assert stored.result == 42
    assert stored.status == 'COMPLETED'
||
172 | |||
173 | |||
def test_sql_observer_interrupted_event_updates_run(sql_obs, sample_run, session):
    """interrupted_event records the stop time and the given status."""
    sql_obs.started_event(**sample_run)
    sql_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED')

    assert session.query(Run).count() == 1
    stored = session.query(Run).first()

    assert stored.stop_time == T2
    assert stored.status == 'INTERRUPTED'
||
183 | |||
184 | |||
def test_sql_observer_failed_event_updates_run(sql_obs, sample_run, session):
    """failed_event records stop time, FAILED status and a newline-joined
    fail trace."""
    sql_obs.started_event(**sample_run)
    trace_lines = ["lots of errors and", "so", "on..."]
    sql_obs.failed_event(fail_time=T2, fail_trace=trace_lines)

    assert session.query(Run).count() == 1
    stored = session.query(Run).first()

    assert stored.stop_time == T2
    assert stored.status == 'FAILED'
    assert stored.fail_trace == "lots of errors and\nso\non..."
||
196 | |||
197 | |||
def test_sql_observer_artifact_event(sql_obs, sample_run, session, tmpfile):
    """artifact_event attaches the file content to the run as an artifact."""
    sql_obs.started_event(**sample_run)

    sql_obs.artifact_event('my_artifact.py', tmpfile.name)

    assert session.query(Run).count() == 1
    stored = session.query(Run).first()

    assert len(stored.artifacts) == 1
    artifact = stored.artifacts[0]

    assert artifact.filename == 'my_artifact.py'
    # content is stored as bytes; compare against the original text.
    assert artifact.content.decode() == tmpfile.content
||
211 | |||
212 | |||
def test_sql_observer_resource_event(sql_obs, sample_run, session, tmpfile):
    """resource_event stores the file as a Resource attached to the run.

    NOTE(review): renamed from ``test_fs_observer_...`` — the test exercises
    the SqlObserver, not the FileStorageObserver; the old prefix was a
    copy-paste leftover and inconsistent with the sibling tests.
    """
    sql_obs.started_event(**sample_run)

    sql_obs.resource_event(tmpfile.name)

    assert session.query(Run).count() == 1
    db_run = session.query(Run).first()

    assert len(db_run.resources) == 1
    res = db_run.resources[0]
    assert res.filename == tmpfile.name
    assert res.md5sum == tmpfile.md5sum
    # content is stored as bytes; compare against the original text.
    assert res.content.decode() == tmpfile.content
||
226 | |||
227 | |||
def test_sql_observer_doesnt_duplicate_sources(sql_obs, sample_run, session, tmpfile):
    """Two observers starting runs with the same source share one Source row.

    NOTE(review): renamed from ``test_fs_observer_...`` — the test exercises
    the SqlObserver, not the FileStorageObserver; the old prefix was a
    copy-paste leftover and inconsistent with the sibling tests.
    """
    sql_obs2 = SqlObserver(sql_obs.engine, session)
    sample_run['_id'] = None
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

    sql_obs.started_event(**sample_run)
    sql_obs2.started_event(**sample_run)

    assert session.query(Run).count() == 2
    assert session.query(Source).count() == 1
||
238 | |||
239 | |||
def test_sql_observer_doesnt_duplicate_resources(sql_obs, sample_run, session, tmpfile):
    """Two observers recording the same resource share one Resource row.

    NOTE(review): renamed from ``test_fs_observer_...`` — the test exercises
    the SqlObserver, not the FileStorageObserver; the old prefix was a
    copy-paste leftover and inconsistent with the sibling tests.
    """
    sql_obs2 = SqlObserver(sql_obs.engine, session)
    sample_run['_id'] = None
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

    sql_obs.started_event(**sample_run)
    sql_obs2.started_event(**sample_run)

    sql_obs.resource_event(tmpfile.name)
    sql_obs2.resource_event(tmpfile.name)

    assert session.query(Run).count() == 2
    assert session.query(Resource).count() == 1
||
253 | |||
254 | |||
def test_sql_observer_equality(sql_obs, engine, session):
    """Observers on the same engine/session compare equal; others do not."""
    twin = SqlObserver(engine, session)
    assert sql_obs == twin

    assert not sql_obs != twin

    assert not sql_obs == 'foo'
    assert sql_obs != 'foo'
||
263 |