1 | #!/usr/bin/env python |
||
2 | # coding=utf-8 |
||
3 | from __future__ import division, print_function, unicode_literals |
||
4 | |||
5 | import datetime |
||
6 | import os.path |
||
7 | import sys |
||
8 | import traceback as tb |
||
9 | |||
10 | from sacred import metrics_logger |
||
11 | from sacred.metrics_logger import linearize_metrics |
||
12 | from sacred.randomness import set_global_seed |
||
13 | from sacred.utils import SacredInterrupt, join_paths, \ |
||
14 | IntervalTimer |
||
15 | from sacred.stdout_capturing import get_stdcapturer |
||
16 | |||
17 | |||
# Marks files that should be filtered from stack traces.
__sacred__ = True
||
19 | |||
20 | |||
class Run(object):
    """Represent and manage a single run of an experiment."""

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = ''
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """Dict with meta information about this run, sent to the observers"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        self.capture_mode = None
        """Determines the way the stdout/stderr are captured"""

        # Heartbeat thread and the event used to stop it; both stay None
        # until _start_heartbeat() actually creates them.
        self._heartbeat = None
        self._stop_heartbeat_event = None
        # Observers that raised inside _safe_call; they are skipped afterwards.
        self._failed_observers = []
        # Capture target bound by the capture context manager in __call__.
        self._output_file = None

        self._metrics = metrics_logger.MetricsLogger()

    def open_resource(self, filename, mode='r'):
        """Open a file and also save it as a resource.

        Opens a file, reports it to the observers as a resource, and returns
        the opened file.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        its path and md5 sum.

        See also :py:meth:`sacred.Experiment.open_resource`.

        Parameters
        ----------
        filename : str
            name of the file that should be opened
        mode : str
            mode that file will be open

        Returns
        -------
        file
            the opened file-object

        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
        return open(filename, mode)

    def add_resource(self, filename):
        """Add a file as a resource.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        its path and md5 sum.

        See also :py:meth:`sacred.Experiment.add_resource`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as a resource
        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)

    def add_artifact(self, filename, name=None):
        """Add a file as an artifact.

        In Sacred terminology an artifact is a file produced by the experiment
        run. In case of a MongoObserver that means storing the file in the
        database.

        See also :py:meth:`sacred.Experiment.add_artifact`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as artifact
        name : str, optional
            optionally set the name of the artifact.
            Defaults to the filename.
        """
        filename = os.path.abspath(filename)
        name = os.path.basename(filename) if name is None else name
        self._emit_artifact_added(name, filename)

    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function
            (``None`` if ``queue_only`` is set and the run was only queued)

        Raises
        ------
        RuntimeError
            if this run has already been started once

        """
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []
        else:
            # Observers with a higher priority are notified first.
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        # Without observers and without an explicit mode, do not capture.
        if self.capture_mode is None and not self.observers:
            capture_mode = "no"
        else:
            capture_mode = self.capture_mode
        capture_mode, capture_stdout = get_stdcapturer(capture_mode)
        self.run_logger.debug('Using capture mode "%s"', capture_mode)

        if self.queue_only:
            self._emit_queued()
            return
        try:
            with capture_stdout() as self._output_file:
                self._emit_started()
                self._start_heartbeat()
                self._execute_pre_run_hooks()
                self.result = self.main_function(*args)
                self._execute_post_run_hooks()
                if self.result is not None:
                    self.run_logger.info('Result: {}'.format(self.result))
                elapsed_time = self._stop_time()
                self.run_logger.info('Completed after %s', elapsed_time)
                # Collect the output while the capture file is still open.
                self._get_captured_output()
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # SacredInterrupt subclasses may carry their own STATUS.
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except Exception:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next drops this frame from the reported stack trace.
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result

    def _get_captured_output(self):
        """Append newly captured output to ``self.captured_out``.

        Does nothing if the capture file has already been closed. Bytes are
        decoded as UTF-8 (with replacement), and ``captured_out_filter``, if
        set, is applied to the accumulated text.
        """
        if self._output_file.closed:
            return
        text = self._output_file.get()
        if isinstance(text, bytes):
            text = text.decode('utf-8', 'replace')
        if self.captured_out:
            text = self.captured_out + text
        if self.captured_out_filter is not None:
            text = self.captured_out_filter(text)
        self.captured_out = text

    def _start_heartbeat(self):
        """Start the heartbeat timer, unless ``beat_interval`` is <= 0."""
        if self.beat_interval > 0:
            self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create(
                self._emit_heartbeat, self.beat_interval)
            self._heartbeat.start()

    def _stop_heartbeat(self):
        """Signal the heartbeat timer to stop and wait briefly for it."""
        # only stop if heartbeat was started
        if self._heartbeat is not None:
            self._stop_heartbeat_event.set()
            self._heartbeat.join(2)

    def _emit_queued(self):
        """Set the status to QUEUED and fire ``queued_event`` on observers.

        The first non-None ID returned by an observer becomes the run ID.
        """
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'queued_event'):
                _id = observer.queued_event(
                    ex_info=self.experiment_info,
                    command=command,
                    host_info=self.host_info,
                    queue_time=queue_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))

    def _emit_started(self):
        """Set the status to RUNNING and fire ``started_event`` on observers.

        Records ``start_time``. The first non-None ID returned by an observer
        becomes the run ID.
        """
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'started_event'):
                _id = observer.started_event(
                    ex_info=self.experiment_info,
                    command=command,
                    host_info=self.host_info,
                    start_time=self.start_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))

    def _emit_heartbeat(self):
        """Send new metrics and a heartbeat event to every observer."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        # Read all measured metrics since last heartbeat
        logged_metrics = self._metrics.get_last_metrics()
        metrics_by_name = linearize_metrics(logged_metrics)
        for observer in self.observers:
            self._safe_call(observer, 'log_metrics',
                            metrics_by_name=metrics_by_name,
                            info=self.info)
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time,
                            result=self.result)

    def _stop_time(self):
        """Record ``stop_time`` and return the elapsed time.

        Returns
        -------
        datetime.timedelta
            time since ``start_time`` rounded to whole seconds, or a zero
            timedelta if the run stopped before ``start_time`` was set
        """
        self.stop_time = datetime.datetime.utcnow()
        if self.start_time is None:
            # The run stopped before _emit_started() recorded a start time
            # (e.g. setting up the output capture failed). Subtracting None
            # would raise a TypeError and mask the original exception.
            return datetime.timedelta(0)
        elapsed_time = datetime.timedelta(
            seconds=round((self.stop_time - self.start_time).total_seconds()))
        return elapsed_time

    def _emit_completed(self, result):
        """Set the status to COMPLETED and fire ``completed_event``."""
        self.status = 'COMPLETED'
        for observer in self.observers:
            self._final_call(observer, 'completed_event',
                             stop_time=self.stop_time,
                             result=result)

    def _emit_interrupted(self, status):
        """Set the given status and fire ``interrupted_event``."""
        self.status = status
        elapsed_time = self._stop_time()
        self.run_logger.warning("Aborted after %s!", elapsed_time)
        for observer in self.observers:
            self._final_call(observer, 'interrupted_event',
                             interrupt_time=self.stop_time,
                             status=status)

    def _emit_failed(self, exc_type, exc_value, trace):
        """Set the status to FAILED, store the trace, fire ``failed_event``."""
        self.status = 'FAILED'
        elapsed_time = self._stop_time()
        self.run_logger.error("Failed after %s!", elapsed_time)
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
        for observer in self.observers:
            self._final_call(observer, 'failed_event',
                             fail_time=self.stop_time,
                             fail_trace=self.fail_trace)

    def _emit_resource_added(self, filename):
        """Fire ``resource_event`` for ``filename`` on every observer."""
        for observer in self.observers:
            self._safe_call(observer, 'resource_event', filename=filename)

    def _emit_artifact_added(self, name, filename):
        """Fire ``artifact_event`` for the artifact on every observer."""
        for observer in self.observers:
            self._safe_call(observer, 'artifact_event',
                            name=name,
                            filename=filename)

    def _safe_call(self, obs, method, **kwargs):
        """Call ``method`` on ``obs`` if it has it and has not failed before.

        An exception marks the observer as failed (later calls skip it) and
        is logged as a warning instead of being propagated.
        """
        if obs not in self._failed_observers and hasattr(obs, method):
            try:
                getattr(obs, method)(**kwargs)
            except Exception as e:
                self._failed_observers.append(obs)
                self.run_logger.warning("An error ocurred in the '{}' "
                                        "observer: {}".format(obs, e))

    def _final_call(self, observer, method, **kwargs):
        """Call ``method`` on ``observer`` if present, logging any exception.

        Unlike ``_safe_call`` this is attempted even for observers that have
        failed earlier.
        """
        if hasattr(observer, method):
            try:
                getattr(observer, method)(**kwargs)
            except Exception:
                # Feels dirty to catch all exceptions, but it is just for
                # finishing up, so we don't want one observer to kill the
                # others
                self.run_logger.error(tb.format_exc())

    def _warn_about_failed_observers(self):
        """Log a warning for every observer that failed during the run."""
        for observer in self._failed_observers:
            self.run_logger.warning("The observer '{}' failed at some point "
                                    "during the run.".format(observer))

    def _execute_pre_run_hooks(self):
        """Call all pre-run hooks in order."""
        for pr in self.pre_run_hooks:
            pr()

    def _execute_post_run_hooks(self):
        """Call all post-run hooks in order."""
        for pr in self.post_run_hooks:
            pr()

    def warn_if_unobserved(self):
        """Warn if there are no observers and that was not asked for."""
        if not self.observers and not self.debug and not self.unobserved:
            self.run_logger.warning("No observers have been added to this run")

    def log_scalar(self, metric_name, value, step=None):
        """
        Add a new measurement.

        The measurement will be processed by the MongoDB observer
        during a heartbeat event.
        Other observers are not yet supported.

        :param metric_name: The name of the metric, e.g. training.loss
        :param value: The measured value
        :param step: The step number (integer), e.g. the iteration number
                    If not specified, an internal counter for each metric
                    is used, incremented by one.
        """
        # Method added in change https://github.com/chovanecm/sacred/issues/4
        # The same as Experiment.log_scalar (if something changes,
        # update the docstring too!)

        return self._metrics.log_scalar_metric(metric_name, value, step)
||
439 |