test_log_metrics()   F

Complexity
    Conditions    23

Size
    Total Lines   80

Duplication
    Lines         0
    Ratio         0 %

Importance
    Changes       2
    Bugs          0
    Features      0

Metric   Value
cc       23
c        2
b        0
f        0
dl       0
loc      80
rs       0

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method, as sketched below.
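As a rough illustration of Extract Method (a hypothetical sketch; the helper and test names below are invented and not taken from this project), a commented block of assertions inside a long test can become a helper whose name is derived from the comment:

# Before (hypothetical): a commented block buried in a long test body.
def test_log_metrics_before(metrics):
    # confirm that only two metric names are registered
    assert len(metrics) == 2
    assert "training.loss" in metrics
    assert "training.accuracy" in metrics


# After (hypothetical): the comment becomes the name of a small helper,
# and the long test simply calls it.
def assert_expected_metric_names(metrics):
    assert len(metrics) == 2
    assert "training.loss" in metrics
    assert "training.accuracy" in metrics


def test_log_metrics_after(metrics):
    assert_expected_metric_names(metrics)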

Complexity

Complex code units like test_log_metrics() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
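As a minimal, hypothetical sketch of Extract Class (the class and field names below are invented for illustration and do not come from this project), fields sharing a common prefix move into a class of their own:

# Before (hypothetical): three fields sharing the "loss_" prefix hint at a
# cohesive component hiding inside the class.
class TrainingRun:
    def __init__(self):
        self.loss_steps = []
        self.loss_values = []
        self.loss_timestamps = []
        self.hostname = 'test_host'


# After (hypothetical): the prefixed fields are extracted into their own class
# and the original class keeps a single reference to it.
class MetricSeries:
    def __init__(self):
        self.steps = []
        self.values = []
        self.timestamps = []


class RefactoredTrainingRun:
    def __init__(self):
        self.loss = MetricSeries()
        self.hostname = 'test_host'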

#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals
import datetime
import hashlib
import os
import tempfile
from copy import copy
import pytest
import json

from sacred.observers.file_storage import FileStorageObserver
from sacred.serializer import restore
from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics


T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0)
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5)


@pytest.fixture()
def sample_run():
    exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'}
    host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'}
    config = {'config': 'True', 'foo': 'bar', 'answer': 42}
    command = 'run'
    meta_info = {'comment': 'test run'}
    return {
        '_id': 'FEDCBA9876543210',
        'ex_info': exp,
        'command': command,
        'host_info': host,
        'start_time': T1,
        'config': config,
        'meta_info': meta_info,
    }


@pytest.fixture()
def dir_obs(tmpdir):
    return tmpdir, FileStorageObserver.create(tmpdir.strpath)


@pytest.fixture
def tmpfile():
    # NOTE: instead of using a with block and delete=True we are creating and
    # manually deleting the file, such that we can close it before running the
    # tests. This is necessary since on Windows we can not open the same file
    # twice, so for the FileStorageObserver to read it, we need to close it.
    f = tempfile.NamedTemporaryFile(suffix='.py', delete=False)

    f.content = 'import sacred\n'
    f.write(f.content.encode())
    f.flush()
    f.seek(0)
    f.md5sum = hashlib.md5(f.read()).hexdigest()

    f.close()

    yield f

    os.remove(f.name)


def test_fs_observer_started_event_creates_rundir(dir_obs, sample_run):
    basedir, obs = dir_obs
    sample_run['_id'] = None
    _id = obs.started_event(**sample_run)
    assert _id is not None
    run_dir = basedir.join(str(_id))
    assert run_dir.exists()
    assert run_dir.join('cout.txt').exists()
    config = json.loads(run_dir.join('config.json').read())
    assert config == sample_run['config']

    run = json.loads(run_dir.join('run.json').read())
    assert run == {
        'experiment': sample_run['ex_info'],
        'command': sample_run['command'],
        'host': sample_run['host_info'],
        'start_time': T1.isoformat(),
        'heartbeat': None,
        'meta': sample_run['meta_info'],
        "resources": [],
        "artifacts": [],
        "status": "RUNNING"
    }


def test_fs_observer_started_event_stores_source(dir_obs, sample_run, tmpfile):
    basedir, obs = dir_obs
    sample_run['ex_info']['sources'] = [[tmpfile.name, tmpfile.md5sum]]

    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    assert run_dir.exists()
    run = json.loads(run_dir.join('run.json').read())
    ex_info = copy(run['experiment'])
    assert ex_info['sources'][0][0] == tmpfile.name
    source_path = ex_info['sources'][0][1]
    source = basedir.join(source_path)
    assert source.exists()
    assert source.read() == 'import sacred\n'


def test_fs_observer_started_event_uses_given_id(dir_obs, sample_run):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    assert _id == sample_run['_id']
    assert basedir.join(_id).exists()


def test_fs_observer_heartbeat_event_updates_run(dir_obs, sample_run):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)
    info = {'my_info': [1, 2, 3], 'nr': 7}
    obs.heartbeat_event(info=info, captured_out='some output', beat_time=T2,
                        result=17)

    assert run_dir.join('cout.txt').read() == 'some output'
    run = json.loads(run_dir.join('run.json').read())

    assert run['heartbeat'] == T2.isoformat()
    assert run['result'] == 17

    assert run_dir.join('info.json').exists()
    i = json.loads(run_dir.join('info.json').read())
    assert info == i


def test_fs_observer_completed_event_updates_run(dir_obs, sample_run):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    obs.completed_event(stop_time=T2, result=42)

    run = json.loads(run_dir.join('run.json').read())
    assert run['stop_time'] == T2.isoformat()
    assert run['status'] == 'COMPLETED'
    assert run['result'] == 42


def test_fs_observer_interrupted_event_updates_run(dir_obs, sample_run):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    obs.interrupted_event(interrupt_time=T2, status='CUSTOM_INTERRUPTION')

    run = json.loads(run_dir.join('run.json').read())
    assert run['stop_time'] == T2.isoformat()
    assert run['status'] == 'CUSTOM_INTERRUPTION'


def test_fs_observer_failed_event_updates_run(dir_obs, sample_run):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    fail_trace = "lots of errors and\nso\non..."
    obs.failed_event(fail_time=T2, fail_trace=fail_trace)

    run = json.loads(run_dir.join('run.json').read())
    assert run['stop_time'] == T2.isoformat()
    assert run['status'] == 'FAILED'
    assert run['fail_trace'] == fail_trace


def test_fs_observer_artifact_event(dir_obs, sample_run, tmpfile):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    obs.artifact_event('my_artifact.py', tmpfile.name)

    artifact = run_dir.join('my_artifact.py')
    assert artifact.exists()
    assert artifact.read() == tmpfile.content

    run = json.loads(run_dir.join('run.json').read())
    assert len(run['artifacts']) == 1
    assert run['artifacts'][0] == artifact.relto(run_dir)


def test_fs_observer_resource_event(dir_obs, sample_run, tmpfile):
    basedir, obs = dir_obs
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(_id)

    obs.resource_event(tmpfile.name)

    res_dir = basedir.join('_resources')
    assert res_dir.exists()
    assert len(res_dir.listdir()) == 1
    assert res_dir.listdir()[0].read() == tmpfile.content

    run = json.loads(run_dir.join('run.json').read())
    assert len(run['resources']) == 1
    assert run['resources'][0] == [tmpfile.name, res_dir.listdir()[0].strpath]


def test_fs_observer_resource_event_does_not_duplicate(dir_obs, sample_run,
                                                       tmpfile):
    basedir, obs = dir_obs
    obs2 = FileStorageObserver.create(obs.basedir)
    obs.started_event(**sample_run)

    obs.resource_event(tmpfile.name)
    # let's have another run from a different observer
    sample_run['_id'] = None
    _id = obs2.started_event(**sample_run)
    run_dir = basedir.join(str(_id))
    obs2.resource_event(tmpfile.name)

    res_dir = basedir.join('_resources')
    assert res_dir.exists()
    assert len(res_dir.listdir()) == 1
    assert res_dir.listdir()[0].read() == tmpfile.content

    run = json.loads(run_dir.join('run.json').read())
    assert len(run['resources']) == 1
    assert run['resources'][0] == [tmpfile.name, res_dir.listdir()[0].strpath]


def test_fs_observer_equality(dir_obs):
    basedir, obs = dir_obs
    obs2 = FileStorageObserver.create(obs.basedir)
    assert obs == obs2
    assert not obs != obs2

    assert not obs == 'foo'
    assert obs != 'foo'


@pytest.fixture
def logged_metrics():
    return [
        ScalarMetricLogEntry("training.loss", 10, datetime.datetime.utcnow(), 1),
        ScalarMetricLogEntry("training.loss", 20, datetime.datetime.utcnow(), 2),
        ScalarMetricLogEntry("training.loss", 30, datetime.datetime.utcnow(), 3),

        ScalarMetricLogEntry("training.accuracy", 10, datetime.datetime.utcnow(), 100),
        ScalarMetricLogEntry("training.accuracy", 20, datetime.datetime.utcnow(), 200),
        ScalarMetricLogEntry("training.accuracy", 30, datetime.datetime.utcnow(), 300),

        ScalarMetricLogEntry("training.loss", 40, datetime.datetime.utcnow(), 10),
        ScalarMetricLogEntry("training.loss", 50, datetime.datetime.utcnow(), 20),
        ScalarMetricLogEntry("training.loss", 60, datetime.datetime.utcnow(), 30)
    ]


def test_log_metrics(dir_obs, sample_run, logged_metrics):
    """Test storing of scalar measurements.

    Test whether measurements logged using _run.metrics.log_scalar_metric
    are being stored in the metrics.json file.

    Metrics are stored as a JSON object with each metric indexed by its name
    (e.g. 'training.loss'). Each named metric holds three parallel lists:
    the iteration steps (steps), the logged values (values), and the
    timestamps at which the measurements were taken (timestamps).
    """

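    # Illustrative sketch (an assumption added for this write-up, not part of
    # the original test): with the first six entries from the logged_metrics
    # fixture above, metrics.json is expected to look roughly like
    #
    #   {
    #       "training.loss":     {"steps": [10, 20, 30],
    #                             "values": [1, 2, 3],
    #                             "timestamps": [<ISO timestamp>, ...]},
    #       "training.accuracy": {"steps": [10, 20, 30],
    #                             "values": [100, 200, 300],
    #                             "timestamps": [<ISO timestamp>, ...]}
    #   }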
    # Start the experiment
    basedir, obs = dir_obs
    sample_run['_id'] = None
    _id = obs.started_event(**sample_run)
    run_dir = basedir.join(str(_id))

    # Initialize the info dictionary and standard output with arbitrary values
    info = {'my_info': [1, 2, 3], 'nr': 7}
    outp = 'some output'

    obs.log_metrics(linearize_metrics(logged_metrics[:6]), info)
    obs.heartbeat_event(info=info, captured_out=outp, beat_time=T1,
                        result=0)

    assert run_dir.join('metrics.json').exists()
    metrics = json.loads(run_dir.join('metrics.json').read())

    # Confirm that only two metric names are registered
    # and that they carry all the information we need.
    assert len(metrics) == 2
    assert "training.loss" in metrics
    assert "training.accuracy" in metrics
    for v in ["steps", "values", "timestamps"]:
        assert v in metrics["training.loss"]
        assert v in metrics["training.accuracy"]

    # Verify that they contain all the information
    # we logged, in the right order.
    loss = metrics["training.loss"]
    assert loss["steps"] == [10, 20, 30]
    assert loss["values"] == [1, 2, 3]
    for i in range(len(loss["timestamps"]) - 1):
        assert loss["timestamps"][i] <= loss["timestamps"][i + 1]

    accuracy = metrics["training.accuracy"]
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]

    # Now process the remaining events.
    # The metrics should be appended to, not overwritten.
    obs.log_metrics(linearize_metrics(logged_metrics[6:]), info)
    obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2,
                        result=0)

    # Reload the new metrics
    metrics = json.loads(run_dir.join('metrics.json').read())

    # The newly added metrics belong to the same run and have the same names,
    # so the total number of metrics should not change.
    assert len(metrics) == 2

    assert "training.loss" in metrics
    loss = metrics["training.loss"]
    assert loss["steps"] == [10, 20, 30, 40, 50, 60]
    assert loss["values"] == [1, 2, 3, 10, 20, 30]
    for i in range(len(loss["timestamps"]) - 1):
        assert loss["timestamps"][i] <= loss["timestamps"][i + 1]

    # Read the training.accuracy metric and verify it's unchanged
    assert "training.accuracy" in metrics
    accuracy = metrics["training.accuracy"]
    assert accuracy["steps"] == [10, 20, 30]
    assert accuracy["values"] == [100, 200, 300]
334