Completed
Push — master ( 6c3661...416f46 )
by Klaus
34s
created

sacred/run.py (2 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import datetime
6
import os.path
7
import sys
8
import traceback as tb
9
10
from sacred import metrics_logger
11
from sacred.metrics_logger import linearize_metrics
12
from sacred.randomness import set_global_seed
13
from sacred.utils import SacredInterrupt, join_paths, \
14
    IntervalTimer
15
from sacred.stdout_capturing import get_stdcapturer
16
17
18
__sacred__ = True  # marks files that should be filtered from stack traces
19
20
21
class Run(object):
    """Represent and manage a single run of an experiment."""

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = ''
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """A dictionary with meta information (e.g. a comment) for this run"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        self.capture_mode = None
        """Determines the way the stdout/stderr are captured"""

        # Internal state: heartbeat timer, observers disabled after an error,
        # and the file-like object produced by the stdout capturer.
        self._heartbeat = None
        self._failed_observers = []
        self._output_file = None

        self._metrics = metrics_logger.MetricsLogger()
    def open_resource(self, filename, mode='r'):
116
        """Open a file and also save it as a resource.
117
118
        Opens a file, reports it to the observers as a resource, and returns
119
        the opened file.
120
121
        In Sacred terminology a resource is a file that the experiment needed
122
        to access during a run. In case of a MongoObserver that means making
123
        sure the file is stored in the database (but avoiding duplicates) along
124
        its path and md5 sum.
125
126
        See also :py:meth:`sacred.Experiment.open_resource`.
127
128
        Parameters
129
        ----------
130
        filename : str
131
            name of the file that should be opened
132
        mode : str
133
            mode that file will be open
134
135
        Returns
136
        -------
137
        file
138
            the opened file-object
139
140
        """
141
        filename = os.path.abspath(filename)
142
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
143
        return open(filename, mode)
144
145
    def add_resource(self, filename):
146
        """Add a file as a resource.
147
148
        In Sacred terminology a resource is a file that the experiment needed
149
        to access during a run. In case of a MongoObserver that means making
150
        sure the file is stored in the database (but avoiding duplicates) along
151
        its path and md5 sum.
152
153
        See also :py:meth:`sacred.Experiment.add_resource`.
154
155
        Parameters
156
        ----------
157
        filename : str
158
            name of the file to be stored as a resource
159
        """
160
        filename = os.path.abspath(filename)
161
        self._emit_resource_added(filename)
162
163
    def add_artifact(self, filename, name=None):
164
        """Add a file as an artifact.
165
166
        In Sacred terminology an artifact is a file produced by the experiment
167
        run. In case of a MongoObserver that means storing the file in the
168
        database.
169
170
        See also :py:meth:`sacred.Experiment.add_artifact`.
171
172
        Parameters
173
        ----------
174
        filename : str
175
            name of the file to be stored as artifact
176
        name : str, optional
177
            optionally set the name of the artifact.
178
            Defaults to the filename.
179
        """
180
        filename = os.path.abspath(filename)
181
        name = os.path.basename(filename) if name is None else name
182
        self._emit_artifact_added(name, filename)
183
184
    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function

        """
        # A Run object is single-use: refuse to be started a second time.
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []
        else:
            # Higher-priority observers are notified first.
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        # With no observers (and no explicit choice) there is nobody to
        # receive captured output, so capturing is disabled.
        if self.capture_mode is None and not self.observers:
            capture_mode = "no"
        else:
            capture_mode = self.capture_mode
        capture_mode, capture_stdout = get_stdcapturer(capture_mode)
        self.run_logger.debug('Using capture mode "%s"', capture_mode)

        if self.queue_only:
            # Only announce the run as queued; do not execute it.
            self._emit_queued()
            return
        try:
            with capture_stdout() as self._output_file:
                self._emit_started()
                self._start_heartbeat()
                self._execute_pre_run_hooks()
                self.result = self.main_function(*args)
                self._execute_post_run_hooks()
                if self.result is not None:
                    self.run_logger.info('Result: {}'.format(self.result))
                elapsed_time = self._stop_time()
                self.run_logger.info('Completed after %s', elapsed_time)
                # Collect any remaining output while capturing is still open.
                self._get_captured_output()
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # Custom SacredInterrupt subclasses may carry their own status.
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except Exception:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next drops this (sacred-internal) frame from the reported
            # traceback.
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result
    def _get_captured_output(self):
249
        if self._output_file.closed:
250
            return
251
        text = self._output_file.get()
252
        if isinstance(text, bytes):
253
            text = text.decode('utf-8', 'replace')
254
        if self.captured_out:
255
            text = self.captured_out + text
256
        if self.captured_out_filter is not None:
257
            text = self.captured_out_filter(text)
258
        self.captured_out = text
259
260
    def _start_heartbeat(self):
261
        if self.beat_interval > 0:
262
            self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create(
263
                self._emit_heartbeat, self.beat_interval)
264
            self._heartbeat.start()
265
266
    def _stop_heartbeat(self):
267
        # only stop if heartbeat was started
268
        if self._heartbeat is not None:
269
            self._stop_heartbeat_event.set()
270
            self._heartbeat.join(2)
271
272
    def _emit_queued(self):
        """Mark the run as QUEUED and fire queued_event on all observers.

        Exceptions from observers are deliberately NOT caught here: the
        experiment SHOULD fail on startup if any of the observers fails.
        """
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'queued_event'):
                continue
            returned_id = observer.queued_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                queue_time=queue_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            # The first observer to report an id determines the run id.
            if self._id is None:
                self._id = returned_id

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))
    def _emit_started(self):
        """Mark the run as RUNNING and fire started_event on all observers.

        Exceptions from observers are deliberately NOT caught here: the
        experiment SHOULD fail on startup if any of the observers fails.
        """
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'started_event'):
                continue
            returned_id = observer.started_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                start_time=self.start_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            # The first observer to report an id determines the run id.
            if self._id is None:
                self._id = returned_id
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))
    def _emit_heartbeat(self):
        """Send captured output, metrics and info to all observers.

        Called periodically by the heartbeat timer.
        """
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        # Flush every metric measured since the previous heartbeat.
        metrics_by_name = linearize_metrics(self._metrics.get_last_metrics())
        for observer in self.observers:
            self._safe_call(observer, 'log_metrics',
                            metrics_by_name=metrics_by_name,
                            info=self.info)
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time,
                            result=self.result)
    def _stop_time(self):
343
        self.stop_time = datetime.datetime.utcnow()
344
        elapsed_time = datetime.timedelta(
345
            seconds=round((self.stop_time - self.start_time).total_seconds()))
346
        return elapsed_time
347
348
    def _emit_completed(self, result):
349
        self.status = 'COMPLETED'
350
        for observer in self.observers:
351
            self._final_call(observer, 'completed_event',
352
                             stop_time=self.stop_time,
353
                             result=result)
354
355
    def _emit_interrupted(self, status):
356
        self.status = status
357
        elapsed_time = self._stop_time()
358
        self.run_logger.warning("Aborted after %s!", elapsed_time)
359
        for observer in self.observers:
360
            self._final_call(observer, 'interrupted_event',
361
                             interrupt_time=self.stop_time,
362
                             status=status)
363
364
    def _emit_failed(self, exc_type, exc_value, trace):
365
        self.status = 'FAILED'
366
        elapsed_time = self._stop_time()
367
        self.run_logger.error("Failed after %s!", elapsed_time)
368
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
369
        for observer in self.observers:
370
            self._final_call(observer, 'failed_event',
371
                             fail_time=self.stop_time,
372
                             fail_trace=self.fail_trace)
373
374
    def _emit_resource_added(self, filename):
375
        for observer in self.observers:
376
            self._safe_call(observer, 'resource_event', filename=filename)
377
378
    def _emit_artifact_added(self, name, filename):
379
        for observer in self.observers:
380
            self._safe_call(observer, 'artifact_event',
381
                            name=name,
382
                            filename=filename)
383
384
    def _safe_call(self, obs, method, **kwargs):
385
        if obs not in self._failed_observers and hasattr(obs, method):
386
            try:
387
                getattr(obs, method)(**kwargs)
388
            except Exception as e:
389
                self._failed_observers.append(obs)
390
                self.run_logger.warning("An error ocurred in the '{}' "
391
                                        "observer: {}".format(obs, e))
392
393
    def _final_call(self, observer, method, **kwargs):
394
        if hasattr(observer, method):
395
            try:
396
                getattr(observer, method)(**kwargs)
397
            except Exception:
398
                # Feels dirty to catch all exceptions, but it is just for
399
                # finishing up, so we don't want one observer to kill the
400
                # others
401
                self.run_logger.error(tb.format_exc())
402
403
    def _warn_about_failed_observers(self):
404
        for observer in self._failed_observers:
405
            self.run_logger.warning("The observer '{}' failed at some point "
406
                                    "during the run.".format(observer))
407
408
    def _execute_pre_run_hooks(self):
409
        for pr in self.pre_run_hooks:
410
            pr()
411
412
    def _execute_post_run_hooks(self):
413
        for pr in self.post_run_hooks:
414
            pr()
415
416
    def warn_if_unobserved(self):
417
        if not self.observers and not self.debug and not self.unobserved:
418
            self.run_logger.warning("No observers have been added to this run")
419
420
    def log_scalar(self, metric_name, value, step=None):
421
        """
422
        Add a new measurement.
423
424
        The measurement will be processed by the MongoDB observer
425
        during a heartbeat event.
426
        Other observers are not yet supported.
427
428
        :param metric_name: The name of the metric, e.g. training.loss
429
        :param value: The measured value
430
        :param step: The step number (integer), e.g. the iteration number
431
                    If not specified, an internal counter for each metric
432
                    is used, incremented by one.
433
        """
434
        # Method added in change https://github.com/chovanecm/sacred/issues/4
435
        # The same as Experiment.log_scalar (if something changes,
436
        # update the docstring too!)
437
438
        return self._metrics.log_scalar_metric(metric_name, value, step)
439