Completed
Push — master ( 50b00b...d515a4 )
by Klaus
11s
created

tests/test_run.py (3 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
from datetime import datetime
5
import mock
6
import os
7
import pytest
8
import tempfile
9
import sys
10
11
from sacred.run import Run
12
from sacred.config.config_summary import ConfigSummary
13
from sacred.utils import (ObserverError, SacredInterrupt, TimeoutInterrupt,
14
                          apply_backspaces_and_linefeeds)
15
16
17
@pytest.fixture
def run():
    """Return a fresh Run assembled entirely from mocks.

    The main function is a Mock returning 123 whose signature is named
    'main_func'; one mock observer (priority 10) is attached, and the same
    mock logger is passed for both logger arguments of ``Run``.
    """
    signature = mock.Mock()
    # `name` is a reserved Mock constructor argument, so assign it afterwards.
    signature.name = 'main_func'
    main_function = mock.Mock(return_value=123, prefix='', signature=signature)
    shared_logger = mock.Mock()
    observers = [mock.Mock(priority=10)]
    config = {'a': 17, 'foo': {'bar': True, 'baz': False}, 'seed': 1234}
    return Run(config, ConfigSummary(), main_function, observers,
               shared_logger, shared_logger, {}, {}, [], [])
28
29
30
def test_run_attributes(run):
    """The descriptive attributes of a Run have the expected types."""
    expected_types = (
        ('config', dict),
        ('config_modifications', ConfigSummary),
        ('experiment_info', dict),
        ('host_info', dict),
        ('info', dict),
    )
    for attribute, expected_type in expected_types:
        assert isinstance(getattr(run, attribute), expected_type), attribute
36
37
38
def test_run_state_attributes(run):
    """Before being called, a Run has no timestamps, no result, no output."""
    for attribute in ('start_time', 'stop_time', 'result'):
        assert getattr(run, attribute) is None, attribute
    assert run.captured_out == ''
43
44
45
def test_run_run(run):
    """Calling a Run executes the main function and records its outcome.

    Checks the returned value, that start/stop timestamps are recent (within
    one second of the current UTC time) and ordered, and that the result and
    (empty) captured output are stored on the run.
    """
    assert run() == 123
    now = datetime.utcnow()
    # Compare by absolute difference: ``t - utcnow()`` is negative for any
    # past ``t``, so a one-sided ``< 1`` check could never fail.
    assert abs((now - run.start_time).total_seconds()) < 1
    assert abs((now - run.stop_time).total_seconds()) < 1
    assert run.start_time <= run.stop_time
    assert run.result == 123
    assert run.captured_out == ''
51
52
53
def test_run_emits_events_if_successful(run):
    """A successful run fires started, heartbeat and completed events only."""
    run()

    obs = run.observers[0]
    for fired in ('started_event', 'heartbeat_event', 'completed_event'):
        assert getattr(obs, fired).called, fired
    for silent in ('interrupted_event', 'failed_event'):
        assert not getattr(obs, silent).called, silent
62
63
64
@pytest.mark.parametrize('exception,status', [
    (KeyboardInterrupt, 'INTERRUPTED'),
    (SacredInterrupt, 'INTERRUPTED'),
    (TimeoutInterrupt, 'TIMEOUT'),
])
def test_run_emits_events_if_interrupted(run, exception, status):
    """An interrupting exception propagates and fires interrupted_event.

    The interrupted event must carry the run's stop time and the status that
    corresponds to the raised exception type; completed and failed events
    must not fire.
    """
    observer_mock = run.observers[0]
    run.main_function.side_effect = exception
    with pytest.raises(exception):
        run()

    for fired in ('started_event', 'heartbeat_event', 'interrupted_event'):
        assert getattr(observer_mock, fired).called, fired
    for silent in ('completed_event', 'failed_event'):
        assert not getattr(observer_mock, silent).called, silent
    observer_mock.interrupted_event.assert_called_with(
        interrupt_time=run.stop_time, status=status)
82
83
84
def test_run_emits_events_if_failed(run):
    """A non-interrupt exception propagates and fires failed_event.

    Neither completed_event nor interrupted_event may fire in that case.
    """
    target = run.observers[0]
    run.main_function.side_effect = TypeError
    with pytest.raises(TypeError):
        run()

    expected_calls = {
        'started_event': True,
        'heartbeat_event': True,
        'completed_event': False,
        'interrupted_event': False,
        'failed_event': True,
    }
    for event_name, should_fire in expected_calls.items():
        assert getattr(target, event_name).called == should_fire, event_name
94
95
96
def test_run_started_event(run):
    """started_event receives the command name and the run's metadata."""
    obs = run.observers[0]
    run()
    expected_kwargs = dict(
        command='main_func',
        ex_info=run.experiment_info,
        host_info=run.host_info,
        start_time=run.start_time,
        config=run.config,
        meta_info={},
        _id=None,
    )
    obs.started_event.assert_called_with(**expected_kwargs)
108
109
110
def test_run_completed_event(run):
    """completed_event receives the run's stop time and its result."""
    obs = run.observers[0]
    run()
    expected = {'stop_time': run.stop_time, 'result': run.result}
    obs.completed_event.assert_called_with(**expected)
117
118
119
def test_run_heartbeat_event(run):
    """The first heartbeat reports the run's info, its output and a fresh time.

    ``run.info`` is populated before the call so the heartbeat's ``info``
    payload can be compared against it.
    """
    observer = run.observers[0]
    run.info['test'] = 321
    run()
    _, call_kwargs = observer.heartbeat_event.call_args_list[0]
    assert call_kwargs['info'] == run.info
    assert call_kwargs['captured_out'] == ""
    # Compare by absolute difference: ``beat_time - utcnow()`` is negative
    # for any past beat_time, so a one-sided ``< 1`` check could never fail.
    beat_age = datetime.utcnow() - call_kwargs['beat_time']
    assert abs(beat_age.total_seconds()) < 1
127
128
129
def test_run_artifact_event(run):
    """add_artifact emits an artifact_event with the file path and name.

    The temporary file is removed in ``finally`` so a failing assertion does
    not leave it behind. Its OS-level descriptor is closed immediately
    because the test only needs the path.
    """
    observer = run.observers[0]
    handle, f_name = tempfile.mkstemp()
    os.close(handle)
    try:
        run.add_artifact(f_name, name='foobar')
        observer.artifact_event.assert_called_with(filename=f_name,
                                                   name='foobar')
    finally:
        os.remove(f_name)
136
137
138
def test_run_resource_event(run):
    """open_resource emits a resource_event carrying the file path.

    The temporary file is removed in ``finally`` so a failing assertion does
    not leave it behind. Its OS-level descriptor is closed immediately
    because the test only needs the path.
    """
    observer = run.observers[0]
    handle, f_name = tempfile.mkstemp()
    os.close(handle)
    try:
        resource = run.open_resource(f_name)
        # NOTE(review): relies on Run.open_resource returning the opened file
        # object (as sacred documents it). Close it explicitly instead of
        # relying on garbage collection before the file is removed below.
        resource.close()
        observer.resource_event.assert_called_with(filename=f_name)
    finally:
        os.remove(f_name)
145
146
147
def test_run_cannot_be_started_twice(run):
    """A Run is single-use: calling it a second time raises RuntimeError."""
    run()
    pytest.raises(RuntimeError, run)
151
152
153
def test_run_observer_failure_on_startup_not_caught(run):
    """An ObserverError raised by started_event propagates out of the run."""
    run.observers[0].started_event.side_effect = ObserverError
    pytest.raises(ObserverError, run)
158
159
160
def test_run_observer_error_in_heartbeat_is_caught(run):
    """A failing heartbeat_event is swallowed and recorded, not propagated.

    The observer ends up in ``run._failed_observers``, the run still
    completes, and started/heartbeat/completed events were all delivered.
    """
    obs = run.observers[0]
    obs.heartbeat_event.side_effect = TypeError
    run()
    assert obs in run._failed_observers
    for event_name in ('started_event', 'heartbeat_event', 'completed_event'):
        assert getattr(obs, event_name).called, event_name
168
169
170
def _run_with_second_observer(run, failing_event, main_exception=None):
    """Call ``run`` with a second observer while the first one's event fails.

    Appends a new ``mock.Mock(priority=20)`` observer, makes the first
    observer's ``failing_event`` raise TypeError, then calls the run. When
    ``main_exception`` is given, the main function raises it and the call is
    expected to propagate it.

    Returns:
        (first_observer, second_observer)
    """
    observer = run.observers[0]
    observer2 = mock.Mock(priority=20)
    run.observers.append(observer2)
    getattr(observer, failing_event).side_effect = TypeError
    if main_exception is None:
        run()
    else:
        run.main_function.side_effect = main_exception
        with pytest.raises(main_exception):
            run()
    return observer, observer2


def test_run_exception_in_completed_event_is_caught(run):
    """A raising completed_event must not block the other observers."""
    observer, observer2 = _run_with_second_observer(run, 'completed_event')
    assert observer.completed_event.called
    assert observer2.completed_event.called


def test_run_exception_in_interrupted_event_is_caught(run):
    """A raising interrupted_event must not block the other observers."""
    observer, observer2 = _run_with_second_observer(
        run, 'interrupted_event', main_exception=KeyboardInterrupt)
    assert observer.interrupted_event.called
    assert observer2.interrupted_event.called


def test_run_exception_in_failed_event_is_caught(run):
    """A raising failed_event must not block the other observers."""
    observer, observer2 = _run_with_second_observer(
        run, 'failed_event', main_exception=AttributeError)
    assert observer.failed_event.called
    assert observer2.failed_event.called
202
203
204
def test_unobserved_run_doesnt_emit(run):
    """Setting ``unobserved`` suppresses every observer event."""
    obs = run.observers[0]
    run.unobserved = True
    run()
    silent_events = ('started_event', 'heartbeat_event', 'completed_event',
                     'interrupted_event', 'failed_event')
    for event_name in silent_events:
        assert not getattr(obs, event_name).called, event_name
213
214
215
def _print_mock_progress():
    """Print the digits 0-9 with no separators, then flush stdout."""
    for i in range(10):
        print(i, end="")
    sys.stdout.flush()


def _captured_output(run, capsys, capture_mode):
    """Call ``run`` with the given capture mode and return its captured_out.

    The main function is replaced by `_print_mock_progress`, and pytest's
    own capturing is disabled during the call so the run's capturing
    mechanism sees the output.
    """
    run.main_function.side_effect = _print_mock_progress
    run.capture_mode = capture_mode
    with capsys.disabled():
        run()
    return run.captured_out


def test_stdout_capturing_no(run, capsys):
    """With capture mode "no", nothing is recorded in captured_out."""
    assert _captured_output(run, capsys, "no") == ''


def test_stdout_capturing_sys(run, capsys):
    """With capture mode "sys", printed output is recorded verbatim."""
    assert _captured_output(run, capsys, "sys") == '0123456789'


@pytest.mark.skipif(sys.platform.startswith('win'),
                    reason="does not work on windows")
def test_stdout_capturing_fd(run, capsys):
    """With capture mode "fd", printed output is recorded verbatim."""
    assert _captured_output(run, capsys, "fd") == '0123456789'
254
255
256
def test_captured_out_filter(run, capsys):
    """captured_out_filter post-processes the captured output.

    The main function writes 'progress 0' and then ten backspace/digit
    pairs. After apply_backspaces_and_linefeeds only 'progress 9' must
    remain.
    """
    def emit_progress():
        out = sys.stdout
        out.write('progress 0')
        out.flush()
        for digit in map(str, range(10)):
            out.write('\b')
            out.write(digit)
            out.flush()

    run.captured_out_filter = apply_backspaces_and_linefeeds
    run.main_function.side_effect = emit_progress
    with capsys.disabled():
        run()
    assert run.captured_out == 'progress 9'
270