1 | #!/usr/bin/env python |
||
2 | # coding=utf-8 |
||
3 | from __future__ import division, print_function, unicode_literals |
||
4 | from datetime import datetime |
||
5 | import mock |
||
6 | import os |
||
7 | import pytest |
||
8 | import tempfile |
||
9 | import sys |
||
10 | |||
11 | from sacred.run import Run |
||
12 | from sacred.config.config_summary import ConfigSummary |
||
13 | from sacred.utils import (ObserverError, SacredInterrupt, TimeoutInterrupt, |
||
14 | apply_backspaces_and_linefeeds) |
||
15 | |||
16 | |||
@pytest.fixture
def run():
    """Build a Run around a mocked main function, logger and observer.

    The mocked main function returns 123 so tests can check ``run.result``.
    A single observer mock with priority 10 is registered.
    """
    # ``name`` is a reserved keyword of the Mock constructor, so it has to be
    # assigned as an attribute after construction.
    main_signature = mock.Mock()
    main_signature.name = 'main_func'
    main_func = mock.Mock(return_value=123, prefix='',
                          signature=main_signature)

    observers = [mock.Mock(priority=10)]
    logger = mock.Mock()
    config = {'a': 17, 'foo': {'bar': True, 'baz': False}, 'seed': 1234}

    # The same logger mock is passed twice. The two dicts become
    # experiment_info and host_info (see test_run_attributes). The two
    # trailing lists are presumably the pre-/post-run hooks — confirm
    # against Run's signature.
    return Run(config, ConfigSummary(), main_func, observers,
               logger, logger,
               {}, {}, [], [])
||
28 | |||
29 | |||
def test_run_attributes(run):
    """A fresh Run exposes its config and metadata with the expected types."""
    expected_types = [
        (run.config, dict),
        (run.config_modifications, ConfigSummary),
        (run.experiment_info, dict),
        (run.host_info, dict),
        (run.info, dict),
    ]
    for value, expected_type in expected_types:
        assert isinstance(value, expected_type)
||
36 | |||
37 | |||
def test_run_state_attributes(run):
    """Before it is called, a Run has no timestamps, no output and no result."""
    for timestamp in (run.start_time, run.stop_time):
        assert timestamp is None
    assert run.captured_out == ''
    assert run.result is None
||
43 | |||
44 | |||
def test_run_run(run):
    """Calling a Run executes main and records timestamps, result and output."""
    assert run() == 123

    now = datetime.utcnow()
    # Measure the distance as abs(now - t). Computing ``t - now`` gives a
    # negative timedelta for any past t, so a ``< 1`` check on it can never
    # fail.
    assert abs((now - run.start_time).total_seconds()) < 1
    assert abs((now - run.stop_time).total_seconds()) < 1
    assert run.start_time <= run.stop_time

    assert run.result == 123
    assert run.captured_out == ''
||
51 | |||
52 | |||
def test_run_emits_events_if_successful(run):
    """A successful run emits started/heartbeat/completed and nothing else."""
    run()

    observer = run.observers[0]
    for event_name in ('started_event', 'heartbeat_event', 'completed_event'):
        assert getattr(observer, event_name).called
    for event_name in ('interrupted_event', 'failed_event'):
        assert not getattr(observer, event_name).called
||
62 | |||
63 | |||
@pytest.mark.parametrize('exception,status', [
    (KeyboardInterrupt, 'INTERRUPTED'),
    (SacredInterrupt, 'INTERRUPTED'),
    (TimeoutInterrupt, 'TIMEOUT'),
])
def test_run_emits_events_if_interrupted(run, exception, status):
    """An interrupt propagates out of the run and is reported with its status."""
    run.main_function.side_effect = exception
    with pytest.raises(exception):
        run()

    observer = run.observers[0]
    assert observer.started_event.called
    assert observer.heartbeat_event.called
    assert not observer.completed_event.called
    assert not observer.failed_event.called
    # assert_called_with fails when the event was never called, so it also
    # covers the plain "was interrupted_event called" check.
    observer.interrupted_event.assert_called_with(
        interrupt_time=run.stop_time, status=status)
||
82 | |||
83 | |||
def test_run_emits_events_if_failed(run):
    """An exception from main propagates and is reported via failed_event only."""
    run.main_function.side_effect = TypeError
    with pytest.raises(TypeError):
        run()

    observer = run.observers[0]
    expected_calls = [
        ('started_event', True),
        ('heartbeat_event', True),
        ('completed_event', False),
        ('interrupted_event', False),
        ('failed_event', True),
    ]
    for event_name, should_be_called in expected_calls:
        assert getattr(observer, event_name).called == should_be_called
||
94 | |||
95 | |||
def test_run_started_event(run):
    """started_event receives the command name, metadata, start time and config."""
    run()
    expected_kwargs = dict(
        command='main_func',
        ex_info=run.experiment_info,
        host_info=run.host_info,
        start_time=run.start_time,
        config=run.config,
        meta_info={},
        _id=None,
    )
    run.observers[0].started_event.assert_called_with(**expected_kwargs)
||
108 | |||
109 | |||
def test_run_completed_event(run):
    """completed_event receives the run's stop time and result."""
    run()
    completed_event = run.observers[0].completed_event
    completed_event.assert_called_with(stop_time=run.stop_time,
                                       result=run.result)
||
117 | |||
118 | |||
def test_run_heartbeat_event(run):
    """The first heartbeat carries run.info, the captured output and a fresh time."""
    observer = run.observers[0]
    run.info['test'] = 321
    run()

    _, first_beat = observer.heartbeat_event.call_args_list[0]
    assert first_beat['info'] == run.info
    assert first_beat['captured_out'] == ""
    # Use abs(now - beat_time). ``beat_time - now`` is negative for any past
    # beat, so the old ``< 1`` comparison was always true.
    elapsed = datetime.utcnow() - first_beat['beat_time']
    assert abs(elapsed.total_seconds()) < 1
||
127 | |||
128 | |||
def test_run_artifact_event(run):
    """add_artifact forwards the file name and artifact name to observers."""
    observer = run.observers[0]
    handle, f_name = tempfile.mkstemp()
    # Only the path is needed, so release the OS-level descriptor right away.
    os.close(handle)
    try:
        run.add_artifact(f_name, name='foobar')
        observer.artifact_event.assert_called_with(filename=f_name,
                                                   name='foobar')
    finally:
        # Remove the temporary file even when the assertion above fails.
        os.remove(f_name)
||
136 | |||
137 | |||
def test_run_resource_event(run):
    """open_resource notifies observers with the resource's file name."""
    observer = run.observers[0]
    handle, f_name = tempfile.mkstemp()
    # Only the path is needed, so release the OS-level descriptor right away.
    os.close(handle)
    try:
        # open_resource is expected to return an open file object (as sacred's
        # Run.open_resource does). Close it so it doesn't leak, and so that
        # os.remove works on platforms that refuse to delete open files.
        with run.open_resource(f_name):
            pass
        observer.resource_event.assert_called_with(filename=f_name)
    finally:
        # Remove the temporary file even when the assertion above fails.
        os.remove(f_name)
||
145 | |||
146 | |||
def test_run_cannot_be_started_twice(run):
    """A Run is single-use: calling it a second time raises RuntimeError."""
    run()  # the first call is expected to succeed
    pytest.raises(RuntimeError, run)
||
151 | |||
152 | |||
def test_run_observer_failure_on_startup_not_caught(run):
    """An ObserverError raised from started_event propagates out of the run."""
    started_event = run.observers[0].started_event
    started_event.side_effect = ObserverError
    with pytest.raises(ObserverError):
        run()
||
158 | |||
159 | |||
def test_run_observer_error_in_heartbeat_is_caught(run):
    """A failing heartbeat marks the observer as failed without aborting the run."""
    failing = run.observers[0]
    failing.heartbeat_event.side_effect = TypeError

    run()  # the heartbeat's TypeError must not escape

    assert failing in run._failed_observers
    for event in (failing.started_event,
                  failing.heartbeat_event,
                  failing.completed_event):
        assert event.called
||
168 | |||
169 | |||
def test_run_exception_in_completed_event_is_caught(run):
    """A crashing completed_event does not keep other observers from being notified."""
    failing = run.observers[0]
    bystander = mock.Mock(priority=20)
    # Appended after the failing observer. Presumably observers are notified
    # in list order, so the bystander runs after the crash — confirm in Run.
    run.observers.append(bystander)
    failing.completed_event.side_effect = TypeError

    run()  # the observer's TypeError must not escape

    for obs in (failing, bystander):
        assert obs.completed_event.called
||
178 | |||
179 | |||
def test_run_exception_in_interrupted_event_is_caught(run):
    """A crashing interrupted_event does not hide the interrupt or skip other observers."""
    failing, bystander = run.observers[0], mock.Mock(priority=20)
    run.observers.append(bystander)
    failing.interrupted_event.side_effect = TypeError
    run.main_function.side_effect = KeyboardInterrupt

    # Only the original interrupt should propagate, not the observer's error.
    with pytest.raises(KeyboardInterrupt):
        run()

    for obs in (failing, bystander):
        assert obs.interrupted_event.called
||
190 | |||
191 | |||
def test_run_exception_in_failed_event_is_caught(run):
    """A crashing failed_event does not hide main's error or skip other observers."""
    failing, bystander = run.observers[0], mock.Mock(priority=20)
    run.observers.append(bystander)
    failing.failed_event.side_effect = TypeError
    run.main_function.side_effect = AttributeError

    # main's AttributeError should propagate, not the observer's TypeError.
    with pytest.raises(AttributeError):
        run()

    for obs in (failing, bystander):
        assert obs.failed_event.called
||
202 | |||
203 | |||
def test_unobserved_run_doesnt_emit(run):
    """With ``unobserved`` set, the run completes without notifying observers."""
    run.unobserved = True
    run()

    observer = run.observers[0]
    silent_events = ('started_event', 'heartbeat_event', 'completed_event',
                     'interrupted_event', 'failed_event')
    for event_name in silent_events:
        assert not getattr(observer, event_name).called
||
213 | |||
214 | |||
def test_stdout_capturing_no(run, capsys):
    """With capture_mode "no", printed output is not recorded in captured_out."""
    def emit_digits():
        # Print 0..9 on one line, flushing so nothing lingers in a buffer.
        for digit in range(10):
            print(digit, end="")
        sys.stdout.flush()

    run.capture_mode = "no"
    run.main_function.side_effect = emit_digits
    # Disable pytest's own capturing so it doesn't interfere with the run's.
    with capsys.disabled():
        run()
    assert run.captured_out == ''
||
226 | |||
227 | |||
def test_stdout_capturing_sys(run, capsys):
    """With capture_mode "sys", printed output ends up in captured_out."""
    def emit_digits():
        # Print 0..9 on one line, flushing so nothing lingers in a buffer.
        for digit in range(10):
            print(digit, end="")
        sys.stdout.flush()

    run.capture_mode = "sys"
    run.main_function.side_effect = emit_digits
    # Disable pytest's own capturing so it doesn't interfere with the run's.
    with capsys.disabled():
        run()
    assert run.captured_out == '0123456789'
||
239 | |||
240 | |||
@pytest.mark.skipif(sys.platform.startswith('win'),
                    reason="does not work on windows")
def test_stdout_capturing_fd(run, capsys):
    """With capture_mode "fd", printed output ends up in captured_out."""
    def emit_digits():
        # Print 0..9 on one line, flushing so nothing lingers in a buffer.
        for digit in range(10):
            print(digit, end="")
        sys.stdout.flush()

    run.capture_mode = "fd"
    run.main_function.side_effect = emit_digits
    # Disable pytest's own capturing so it doesn't interfere with the run's.
    with capsys.disabled():
        run()
    assert run.captured_out == '0123456789'
||
254 | |||
255 | |||
def test_captured_out_filter(run, capsys):
    """captured_out_filter collapses backspace-rewritten progress into its final text."""
    def emit_progress():
        # Write "progress 0", then repeatedly erase the last character with a
        # backspace and write the next digit.
        sys.stdout.write('progress 0')
        sys.stdout.flush()
        for digit in range(10):
            sys.stdout.write('\b')
            sys.stdout.write(str(digit))
            sys.stdout.flush()

    run.main_function.side_effect = emit_progress
    run.captured_out_filter = apply_backspaces_and_linefeeds
    with capsys.disabled():
        run()
    assert run.captured_out == 'progress 9'
||
270 |