Completed
Push — master ( 6c3661...416f46 )
by Klaus
34s
created

tests/test_observers/test_mongo_observer.py (3 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
import datetime
5
import mock
6
import pytest
7
8
from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics
9
10
pymongo = pytest.importorskip("pymongo")
11
mongomock = pytest.importorskip("mongomock")
12
13
from sacred.dependencies import get_digest
14
from sacred.observers.mongo import (MongoObserver, force_bson_encodeable)
15
16
T1 = datetime.datetime(1999, 5, 4, 3, 2, 1)
17
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5)
18
19
20
@pytest.fixture
def mongo_obs():
    """Provide a MongoObserver backed by an in-memory mongomock database.

    The ``runs`` and ``metrics`` collections come from a fresh mongomock
    client; the file store is a ``MagicMock`` so tests can inspect calls
    made on it.
    """
    database = mongomock.MongoClient().db
    return MongoObserver(database.runs, mock.MagicMock(),
                         metrics_collection=database.metrics)
27
28
29
@pytest.fixture()
def sample_run():
    """Return the keyword arguments the tests pass to ``started_event``."""
    return {
        '_id': 'FEDCBA9876543210',
        'ex_info': {
            'name': 'test_exp',
            'sources': [],
            'doc': '',
            'base_dir': '/tmp',
        },
        'command': 'run',
        'host_info': {
            'hostname': 'test_host',
            'cpu_count': 1,
            'python_version': '3.4',
        },
        'start_time': T1,
        'config': {'config': 'True', 'foo': 'bar', 'answer': 42},
        'meta_info': {'comment': 'test run'},
    }
45
46
47
def test_mongo_observer_started_event_creates_run(mongo_obs, sample_run):
    """started_event without an ``_id`` stores exactly one full run entry."""
    # No id is supplied, so the observer must hand back one of its own.
    sample_run['_id'] = None
    new_id = mongo_obs.started_event(**sample_run)
    assert new_id is not None
    assert mongo_obs.runs.count() == 1

    expected_entry = {
        '_id': new_id,
        'format': mongo_obs.VERSION,
        'status': 'RUNNING',
        'experiment': sample_run['ex_info'],
        'command': sample_run['command'],
        'host': sample_run['host_info'],
        'config': sample_run['config'],
        'meta': sample_run['meta_info'],
        'start_time': sample_run['start_time'],
        'heartbeat': None,
        'info': {},
        'captured_out': '',
        'artifacts': [],
        'resources': [],
    }
    assert mongo_obs.runs.find_one() == expected_entry
69
70
71
def test_mongo_observer_started_event_uses_given_id(mongo_obs, sample_run):
    """started_event returns and stores the ``_id`` passed by the caller."""
    given_id = sample_run['_id']
    returned_id = mongo_obs.started_event(**sample_run)
    assert returned_id == given_id
    assert mongo_obs.runs.count() == 1

    stored = mongo_obs.runs.find_one()
    assert stored['_id'] == given_id
77
78
79
def test_mongo_observer_equality(mongo_obs):
    """Check ``==`` and ``!=`` against a twin observer and a non-observer."""
    # The twin shares the runs collection but gets its own mocked file store.
    twin = MongoObserver(mongo_obs.runs, mock.MagicMock())

    # Both operators are spelled out on purpose so each one is exercised.
    assert mongo_obs == twin
    assert not mongo_obs != twin

    assert not mongo_obs == 'foo'
    assert mongo_obs != 'foo'
0 ignored issues
show
This code seems to be duplicated in your project.
Loading history...
88
89
90
def test_mongo_observer_heartbeat_event_updates_run(mongo_obs, sample_run):
    """heartbeat_event writes beat time, result, info and output to the run."""
    mongo_obs.started_event(**sample_run)

    run_info = {'my_info': [1, 2, 3], 'nr': 7}
    stdout_text = 'some output'
    mongo_obs.heartbeat_event(info=run_info, captured_out=stdout_text,
                              beat_time=T2, result=1337)

    # The heartbeat must update the existing run, not insert a second one.
    assert mongo_obs.runs.count() == 1
    stored = mongo_obs.runs.find_one()
    assert stored['heartbeat'] == T2
    assert stored['result'] == 1337
    assert stored['info'] == run_info
    assert stored['captured_out'] == stdout_text
0 ignored issues
show
This code seems to be duplicated in your project.
Loading history...
104
105
106
def test_mongo_observer_completed_event_updates_run(mongo_obs, sample_run):
    """completed_event records stop time, result and COMPLETED status."""
    mongo_obs.started_event(**sample_run)
    mongo_obs.completed_event(stop_time=T2, result=42)

    # Still a single run: the event updates it in place.
    assert mongo_obs.runs.count() == 1
    stored = mongo_obs.runs.find_one()
    assert stored['stop_time'] == T2
    assert stored['result'] == 42
    assert stored['status'] == 'COMPLETED'
116
117
118
def test_mongo_observer_interrupted_event_updates_run(mongo_obs, sample_run):
    """interrupted_event records the stop time and the given status."""
    mongo_obs.started_event(**sample_run)
    mongo_obs.interrupted_event(interrupt_time=T2, status='INTERRUPTED')

    # Still a single run: the event updates it in place.
    assert mongo_obs.runs.count() == 1
    stored = mongo_obs.runs.find_one()
    assert stored['stop_time'] == T2
    assert stored['status'] == 'INTERRUPTED'
0 ignored issues
show
This code seems to be duplicated in your project.
Loading history...
127
128
129
def test_mongo_observer_failed_event_updates_run(mongo_obs, sample_run):
    """failed_event records stop time, FAILED status and the traceback."""
    mongo_obs.started_event(**sample_run)

    traceback_text = "lots of errors and\nso\non..."
    mongo_obs.failed_event(fail_time=T2, fail_trace=traceback_text)

    # Still a single run: the event updates it in place.
    assert mongo_obs.runs.count() == 1
    stored = mongo_obs.runs.find_one()
    assert stored['stop_time'] == T2
    assert stored['status'] == 'FAILED'
    assert stored['fail_trace'] == traceback_text
141
142
143
def test_mongo_observer_artifact_event(mongo_obs, sample_run):
    """artifact_event puts the file into the store and lists it on the run."""
    mongo_obs.started_event(**sample_run)

    artifact_name = 'mysetup'
    # NOTE(review): "setup.py" is resolved relative to the working directory,
    # so this presumably expects to be run from the project root - confirm.
    source_file = "setup.py"
    mongo_obs.artifact_event(artifact_name, source_file)

    put = mongo_obs.fs.put
    assert put.called
    # call_args[1] holds the keyword arguments of the most recent call.
    put_kwargs = put.call_args[1]
    assert put_kwargs['filename'].endswith(artifact_name)

    stored = mongo_obs.runs.find_one()
    assert stored['artifacts']
156
157
158
def test_mongo_observer_resource_event(mongo_obs, sample_run):
    """resource_event checks the store and records (filename, digest)."""
    mongo_obs.started_event(**sample_run)

    # NOTE(review): "setup.py" is resolved relative to the working directory,
    # so this presumably expects to be run from the project root - confirm.
    resource_path = "setup.py"
    expected_digest = get_digest(resource_path)

    mongo_obs.resource_event(resource_path)

    exists = mongo_obs.fs.exists
    assert exists.called
    exists.assert_any_call(filename=resource_path)

    stored = mongo_obs.runs.find_one()
    # for some reason py27 returns this as tuples and py36 as lists,
    # so normalise every entry to a tuple before comparing
    recorded = list(map(tuple, stored['resources']))
    assert recorded == [(resource_path, expected_digest)]
172
173
174
def test_force_bson_encodable_doesnt_change_valid_document():
    """A document with only legal keys and values compares equal afterwards."""
    valid_doc = {
        'int': 1,
        'string': 'foo',
        'float': 23.87,
        'list': ['a', 1, True],
        'bool': True,
        # '$' and '.' are only a problem in keys; here they sit in the value.
        'cr4zy: _but_ [legal) Key!': '$illegal.key.as.value',
        'datetime': datetime.datetime.utcnow(),
        'tuple': (1, 2.0, 'three'),
        'none': None,
    }
    assert force_bson_encodeable(valid_doc) == valid_doc
180
181
182
def test_force_bson_encodable_substitutes_illegal_value_with_strings():
    """Illegal values become their ``str()``; illegal keys are rewritten.

    Expected rewrites, as pinned by the assertion below: module objects
    turn into ``str(module)``, a leading ``$`` in a key becomes ``@``,
    a ``.`` in a key becomes ``,`` and a non-string key is stringified
    (with the same ``.`` replacement applied).
    """
    raw_doc = {
        'a_module': datetime,
        'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]},
        'nested': {'dict': {'with': {'illegal_module': mock}}},
        '$illegal': 'because it starts with a $',
        'il.legal': 'because it contains a .',
        12.7: 'illegal because it is not a string key',
    }
    sanitized_doc = {
        'a_module': str(datetime),
        'some_legal_stuff': {'foo': 'bar', 'baz': [1, 23, 4]},
        'nested': {'dict': {'with': {'illegal_module': str(mock)}}},
        '@illegal': 'because it starts with a $',
        'il,legal': 'because it contains a .',
        '12,7': 'illegal because it is not a string key',
    }
    assert force_bson_encodeable(raw_doc) == sanitized_doc
212
213
214
@pytest.fixture
def logged_metrics():
    """Return nine scalar log entries, each stamped at creation time.

    Order matters to the consumers: three ``training.loss`` entries, three
    ``training.accuracy`` entries, then three more ``training.loss`` entries.
    """
    # (metric name, step, value) in the order the entries are emitted
    measurements = [
        ("training.loss", 10, 1),
        ("training.loss", 20, 2),
        ("training.loss", 30, 3),

        ("training.accuracy", 10, 100),
        ("training.accuracy", 20, 200),
        ("training.accuracy", 30, 300),

        ("training.loss", 40, 10),
        ("training.loss", 50, 20),
        ("training.loss", 60, 30),
    ]
    return [
        ScalarMetricLogEntry(name, step, datetime.datetime.utcnow(), value)
        for name, step, value in measurements
    ]
229
230
231
def test_log_metrics(mongo_obs, sample_run, logged_metrics):
    """
    Test storing scalar measurements

    Test whether measurements logged using _run.metrics.log_scalar_metric
    are being stored in the 'metrics' collection
    and that the experiment 'info' dictionary contains a valid reference
    to the metrics collection for each of the metrics.

    Metrics are identified by name (e.g.: 'training.loss') and by the
    experiment run that produced them. Each metric contains a list of x values
    (e.g. iteration step), y values (measured values) and timestamps of when
    each of the measurements was taken.
    """
    # NOTE(review): Collection.count() is deprecated in newer PyMongo
    # releases in favour of count_documents({}); it is kept here to match
    # the rest of this file - confirm against the pinned driver version.

    # Start the experiment
    mongo_obs.started_event(**sample_run)

    # Initialize the info dictionary and standard output with arbitrary values
    info = {'my_info': [1, 2, 3], 'nr': 7}
    outp = 'some output'

    # Take first 6 measured events, group them by metric name
    # and store the measured series to the 'metrics' collection
    # and reference the newly created records in the 'info' dictionary.
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[:6]), info)
    # Call standard heartbeat event (store the info dictionary to the database)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1,
                              result=0)

    # There should be only one run stored
    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    # ... and the info dictionary should contain a list of created metrics
    assert "metrics" in db_run['info']
    assert isinstance(db_run['info']["metrics"], list)

    # The metrics, stored in the metrics collection,
    # should be two (training.loss and training.accuracy)
    assert mongo_obs.metrics.count() == 2
    # Read the training.loss metric and make sure it references the correct
    # run and that the run (in the info dictionary) references the correct
    # metric record.
    loss = mongo_obs.metrics.find_one(
        {"name": "training.loss", "run_id": db_run['_id']})
    assert ({"name": "training.loss", "id": str(loss["_id"])}
            in db_run['info']["metrics"])
    assert loss["steps"] == [10, 20, 30]
    assert loss["values"] == [1, 2, 3]
    # Timestamps must be non-decreasing: compare each one with its successor.
    for earlier, later in zip(loss["timestamps"], loss["timestamps"][1:]):
        assert earlier <= later

    # Read the training.accuracy metric and check the references
    # as with the training.loss above
    accuracy = mongo_obs.metrics.find_one(
        {"name": "training.accuracy", "run_id": db_run['_id']})
    assert ({"name": "training.accuracy", "id": str(accuracy["_id"])}
            in db_run['info']["metrics"])
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]

    # Now, process the remaining events
    # The metrics shouldn't be overwritten, but appended instead.
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[6:]), info)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2,
                              result=0)

    assert mongo_obs.runs.count() == 1
    db_run = mongo_obs.runs.find_one()
    assert "metrics" in db_run['info']

    # The newly added metrics belong to the same run and have the same names,
    # so the total number of metrics should not change.
    assert mongo_obs.metrics.count() == 2
    loss = mongo_obs.metrics.find_one(
        {"name": "training.loss", "run_id": db_run['_id']})
    assert ({"name": "training.loss", "id": str(loss["_id"])}
            in db_run['info']["metrics"])
    # ... but the values should be appended to the original list
    assert loss["steps"] == [10, 20, 30, 40, 50, 60]
    assert loss["values"] == [1, 2, 3, 10, 20, 30]
    for earlier, later in zip(loss["timestamps"], loss["timestamps"][1:]):
        assert earlier <= later

    # No accuracy entries were in the second batch, so it stays unchanged.
    accuracy = mongo_obs.metrics.find_one(
        {"name": "training.accuracy", "run_id": db_run['_id']})
    assert ({"name": "training.accuracy", "id": str(accuracy["_id"])}
            in db_run['info']["metrics"])
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]

    # Make sure that when starting a new experiment, new records in metrics
    # are created instead of appending to the old ones.
    sample_run["_id"] = "NEWID"
    # Start the experiment
    mongo_obs.started_event(**sample_run)
    mongo_obs.log_metrics(linearize_metrics(logged_metrics[:4]), info)
    mongo_obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1,
                              result=0)
    # A new run has been created
    assert mongo_obs.runs.count() == 2
    # Another 2 metrics have been created
    assert mongo_obs.metrics.count() == 4