1 | #!/usr/bin/env python |
||
2 | # coding=utf-8 |
||
3 | from __future__ import division, print_function, unicode_literals |
||
4 | import datetime |
||
5 | import mock |
||
6 | import pytest |
||
7 | |||
8 | from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics |
||
9 | |||
10 | pymongo = pytest.importorskip("pymongo") |
||
11 | mongomock = pytest.importorskip("mongomock") |
||
12 | |||
13 | from sacred.dependencies import get_digest |
||
14 | from sacred.observers.mongo import (MongoObserver, force_bson_encodeable) |
||
15 | |||
# Fixed, deterministic timestamps reused across the tests below
# (start time / heartbeat / stop time).  T2 is strictly later than T1.
T1 = datetime.datetime(1999, 5, 4, 3, 2, 1)
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5)
||
18 | |||
19 | |||
@pytest.fixture
def mongo_obs():
    """Return a MongoObserver backed by an in-memory mongomock database.

    The GridFS handle is replaced with a MagicMock so tests can assert on
    ``fs.put`` / ``fs.exists`` calls without a real filesystem.
    """
    db = mongomock.MongoClient().db
    runs = db.runs
    metrics = db.metrics
    fs = mock.MagicMock()
    return MongoObserver(runs, fs, metrics_collection=metrics)
||
27 | |||
28 | |||
@pytest.fixture()
def sample_run():
    """Return keyword arguments for a minimal ``started_event`` call.

    The dict mirrors the signature of ``MongoObserver.started_event`` and
    can be passed as ``mongo_obs.started_event(**sample_run)``.
    """
    exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'}
    host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'}
    config = {'config': 'True', 'foo': 'bar', 'answer': 42}
    command = 'run'
    meta_info = {'comment': 'test run'}
    return {
        '_id': 'FEDCBA9876543210',
        'ex_info': exp,
        'command': command,
        'host_info': host,
        'start_time': T1,
        'config': config,
        'meta_info': meta_info,
    }
||
45 | |||
46 | |||
def test_mongo_observer_started_event_creates_run(mongo_obs, sample_run):
    """started_event with _id=None must insert one fully-populated run doc."""
    sample_run['_id'] = None
    _id = mongo_obs.started_event(**sample_run)
    # The observer must generate an id when none was supplied.
    assert _id is not None
    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    # The stored document is checked field-by-field against the inputs.
    assert db_run == {
        '_id': _id,
        'experiment': sample_run['ex_info'],
        'format': mongo_obs.VERSION,
        'command': sample_run['command'],
        'host': sample_run['host_info'],
        'start_time': sample_run['start_time'],
        'heartbeat': None,
        'info': {},
        'captured_out': '',
        'artifacts': [],
        'config': sample_run['config'],
        'meta': sample_run['meta_info'],
        'status': 'RUNNING',
        'resources': []
    }
||
69 | |||
70 | |||
def test_mongo_observer_started_event_uses_given_id(mongo_obs, sample_run):
    """A caller-supplied _id must be kept, not replaced by a generated one."""
    _id = mongo_obs.started_event(**sample_run)
    assert _id == sample_run['_id']
    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert db_run['_id'] == sample_run['_id']
||
77 | |||
78 | |||
def test_mongo_observer_equality(mongo_obs):
    """Observers sharing the same runs collection compare equal; others don't."""
    runs = mongo_obs.runs
    fs = mock.MagicMock()
    m = MongoObserver(runs, fs)
    assert mongo_obs == m
    assert not mongo_obs != m

    # Comparison against a non-observer must be False / unequal,
    # not raise or return NotImplemented to the caller.
    assert not mongo_obs == 'foo'
    assert mongo_obs != 'foo'
|||
88 | |||
89 | |||
def test_mongo_observer_heartbeat_event_updates_run(mongo_obs, sample_run):
    """heartbeat_event must update heartbeat time, result, info and output."""
    mongo_obs.started_event(**sample_run)

    info = {'my_info': [1, 2, 3], 'nr': 7}
    outp = 'some output'
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2,
                              result=1337)

    # The heartbeat updates the existing run in place — no new document.
    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert db_run['heartbeat'] == T2
    assert db_run['result'] == 1337
    assert db_run['info'] == info
    assert db_run['captured_out'] == outp
|||
104 | |||
105 | |||
def test_mongo_observer_completed_event_updates_run(mongo_obs, sample_run):
    """completed_event must set stop_time, result and COMPLETED status."""
    mongo_obs.started_event(**sample_run)

    mongo_obs.completed_event(stop_time=T2, result=42)

    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert db_run['stop_time'] == T2
    assert db_run['result'] == 42
    assert db_run['status'] == 'COMPLETED'
||
116 | |||
117 | |||
def test_mongo_observer_interrupted_event_updates_run(mongo_obs, sample_run):
    """interrupted_event must set stop_time and the given interrupt status."""
    mongo_obs.started_event(**sample_run)

    mongo_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED')

    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert db_run['stop_time'] == T2
    assert db_run['status'] == 'INTERRUPTED'
|||
127 | |||
128 | |||
def test_mongo_observer_failed_event_updates_run(mongo_obs, sample_run):
    """failed_event must set stop_time, FAILED status and the traceback."""
    mongo_obs.started_event(**sample_run)

    fail_trace = "lots of errors and\nso\non..."
    mongo_obs.failed_event(fail_time=T2,
                           fail_trace=fail_trace)

    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert db_run['stop_time'] == T2
    assert db_run['status'] == 'FAILED'
    assert db_run['fail_trace'] == fail_trace
||
141 | |||
142 | |||
def test_mongo_observer_artifact_event(mongo_obs, sample_run):
    """artifact_event must upload the file to GridFS and record it on the run."""
    mongo_obs.started_event(**sample_run)

    filename = "setup.py"
    name = 'mysetup'

    mongo_obs.artifact_event(name, filename)

    # The upload goes through the (mocked) GridFS under the artifact name.
    assert mongo_obs.fs.put.called
    assert mongo_obs.fs.put.call_args[1]['filename'].endswith(name)

    db_run = mongo_obs.runs.find_one()
    assert db_run['artifacts']
||
156 | |||
157 | |||
def test_mongo_observer_resource_event(mongo_obs, sample_run):
    """resource_event must check GridFS and store (filename, md5) on the run."""
    mongo_obs.started_event(**sample_run)

    filename = "setup.py"
    md5 = get_digest(filename)

    mongo_obs.resource_event(filename)

    # The observer checks for an existing copy before re-uploading.
    assert mongo_obs.fs.exists.called
    mongo_obs.fs.exists.assert_any_call(filename=filename)

    db_run = mongo_obs.runs.find_one()
    # for some reason py27 returns this as tuples and py36 as lists
    assert [tuple(r) for r in db_run['resources']] == [(filename, md5)]
||
172 | |||
173 | |||
def test_force_bson_encodable_doesnt_change_valid_document():
    """A document of already BSON-encodeable values must pass through unchanged."""
    d = {'int': 1, 'string': 'foo', 'float': 23.87, 'list': ['a', 1, True],
         'bool': True, 'cr4zy: _but_ [legal) Key!': '$illegal.key.as.value',
         'datetime': datetime.datetime.utcnow(), 'tuple': (1, 2.0, 'three'),
         'none': None}
    assert force_bson_encodeable(d) == d
||
180 | |||
181 | |||
def test_force_bson_encodable_substitutes_illegal_value_with_strings():
    """Non-encodeable values become str(); illegal keys are sanitized.

    BSON forbids keys starting with '$', keys containing '.', and non-string
    keys — these are rewritten ('$'->'@', '.'->',', key stringified), even
    deep inside nested dicts.
    """
    d = {
        'a_module': datetime,
        'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]},
        'nested': {
            'dict': {
                'with': {
                    'illegal_module': mock
                }
            }
        },
        '$illegal': 'because it starts with a $',
        'il.legal': 'because it contains a .',
        12.7: 'illegal because it is not a string key'
    }
    expected = {
        'a_module': str(datetime),
        'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]},
        'nested': {
            'dict': {
                'with': {
                    'illegal_module': str(mock)
                }
            }
        },
        '@illegal': 'because it starts with a $',
        'il,legal': 'because it contains a .',
        '12,7': 'illegal because it is not a string key'
    }
    assert force_bson_encodeable(d) == expected
||
212 | |||
213 | |||
@pytest.fixture
def logged_metrics():
    """Return a chronological list of scalar metric entries for two metrics.

    The first six entries cover steps 10-30 of 'training.loss' and
    'training.accuracy'; the last three continue 'training.loss' at
    steps 40-60 so tests can verify appending to an existing metric.
    """
    return [
        ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1),
        ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2),
        ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3),

        ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200),
        ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300),

        ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10),
        ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20),
        ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30)
    ]
||
229 | |||
230 | |||
def test_log_metrics(mongo_obs, sample_run, logged_metrics):
    """
    Test storing scalar measurements

    Test whether measurements logged using _run.metrics.log_scalar_metric
    are being stored in the 'metrics' collection
    and that the experiment 'info' dictionary contains a valid reference
    to the metrics collection for each of the metric.

    Metrics are identified by name (e.g.: 'training.loss') and by the
    experiment run that produced them. Each metric contains a list of x values
    (e.g. iteration step), y values (measured values) and timestamps of when
    each of the measurements was taken.
    """

    # Start the experiment
    mongo_obs.started_event(**sample_run)

    # Initialize the info dictionary and standard output with arbitrary values
    info = {'my_info': [1, 2, 3], 'nr': 7}
    outp = 'some output'

    # Take first 6 measured events, group them by metric name
    # and store the measured series to the 'metrics' collection
    # and reference the newly created records in the 'info' dictionary.
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[:6]), info)
    # Call standard heartbeat event (store the info dictionary to the database)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1,
                              result=0)

    # There should be only one run stored
    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    # ... and the info dictionary should contain a list of created metrics
    assert "metrics" in db_run['info']
    assert type(db_run['info']["metrics"]) == list

    # The metrics, stored in the metrics collection,
    # should be two (training.loss and training.accuracy)
    assert mongo_obs.metrics.count() == 2
    # Read the training.loss metric and make sure it references the correct run
    # and that the run (in the info dictionary) references the correct metric record.
    loss = mongo_obs.metrics.find_one({"name": "training.loss", "run_id": db_run['_id']})
    assert {"name": "training.loss", "id": str(loss["_id"])} in db_run['info']["metrics"]
    assert loss["steps"] == [10, 20, 30]
    assert loss["values"] == [1, 2, 3]
    for i in range(len(loss["timestamps"]) - 1):
        assert loss["timestamps"][i] <= loss["timestamps"][i + 1]

    # Read the training.accuracy metric and check the references as with the training.loss above
    accuracy = mongo_obs.metrics.find_one({"name": "training.accuracy", "run_id": db_run['_id']})
    assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run['info']["metrics"]
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]

    # Now, process the remaining events
    # The metrics shouldn't be overwritten, but appended instead.
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[6:]), info)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2,
                              result=0)

    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert "metrics" in db_run['info']

    # The newly added metrics belong to the same run and have the same names, so the total number
    # of metrics should not change.
    assert mongo_obs.metrics.count() == 2
    loss = mongo_obs.metrics.find_one({"name": "training.loss", "run_id": db_run['_id']})
    assert {"name": "training.loss", "id": str(loss["_id"])} in db_run['info']["metrics"]
    # ... but the values should be appended to the original list
    assert loss["steps"] == [10, 20, 30, 40, 50, 60]
    assert loss["values"] == [1, 2, 3, 10, 20, 30]
    for i in range(len(loss["timestamps"]) - 1):
        assert loss["timestamps"][i] <= loss["timestamps"][i + 1]

    accuracy = mongo_obs.metrics.find_one({"name": "training.accuracy", "run_id": db_run['_id']})
    assert {"name": "training.accuracy", "id": str(accuracy["_id"])} in db_run['info']["metrics"]
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]

    # Make sure that when starting a new experiment, new records in metrics are created
    # instead of appending to the old ones.
    sample_run["_id"] = "NEWID"
    # Start the experiment
    mongo_obs.started_event(**sample_run)
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[:4]), info)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1,
                              result=0)
    # A new run has been created
    assert mongo_obs.runs.count() == 2
    # Another 2 metrics have been created
    assert mongo_obs.metrics.count() == 4