Completed
Push — master ( 576803...5d8e11 )
by Klaus
28s
created

sacred/run.py (2 issues)

1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import datetime
6
import os.path
7
import sys
8
import traceback as tb
9
10
from sacred import metrics_logger
11
from sacred.metrics_logger import linearize_metrics
12
from sacred.randomness import set_global_seed
13
from sacred.utils import SacredInterrupt, join_paths, \
14
    IntervalTimer
15
from sacred.stdout_capturing import get_stdcapturer
16
17
18
class Run(object):
    """Represent and manage a single run of an experiment."""

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = ''
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """A dict of meta information about this run that is passed on to the
        observers (e.g. 'queue_time' is stored here when the run is queued)"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        self.capture_mode = None
        """Determines the way the stdout/stderr are captured"""

        # Internal bookkeeping, not part of the documented state:
        self._heartbeat = None
        self._failed_observers = []
        self._output_file = None

        self._metrics = metrics_logger.MetricsLogger()
111
112
    def open_resource(self, filename, mode='r'):
113
        """Open a file and also save it as a resource.
114
115
        Opens a file, reports it to the observers as a resource, and returns
116
        the opened file.
117
118
        In Sacred terminology a resource is a file that the experiment needed
119
        to access during a run. In case of a MongoObserver that means making
120
        sure the file is stored in the database (but avoiding duplicates) along
121
        its path and md5 sum.
122
123
        See also :py:meth:`sacred.Experiment.open_resource`.
124
125
        Parameters
126
        ----------
127
        filename : str
128
            name of the file that should be opened
129
        mode : str
130
            mode that file will be open
131
132
        Returns
133
        -------
134
        file
135
            the opened file-object
136
137
        """
138
        filename = os.path.abspath(filename)
139
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
140
        return open(filename, mode)
141
142
    def add_resource(self, filename):
143
        """Add a file as a resource.
144
145
        In Sacred terminology a resource is a file that the experiment needed
146
        to access during a run. In case of a MongoObserver that means making
147
        sure the file is stored in the database (but avoiding duplicates) along
148
        its path and md5 sum.
149
150
        See also :py:meth:`sacred.Experiment.add_resource`.
151
152
        Parameters
153
        ----------
154
        filename : str
155
            name of the file to be stored as a resource
156
        """
157
        filename = os.path.abspath(filename)
158
        self._emit_resource_added(filename)
159
160
    def add_artifact(self, filename, name=None):
161
        """Add a file as an artifact.
162
163
        In Sacred terminology an artifact is a file produced by the experiment
164
        run. In case of a MongoObserver that means storing the file in the
165
        database.
166
167
        See also :py:meth:`sacred.Experiment.add_artifact`.
168
169
        Parameters
170
        ----------
171
        filename : str
172
            name of the file to be stored as artifact
173
        name : str, optional
174
            optionally set the name of the artifact.
175
            Defaults to the filename.
176
        """
177
        filename = os.path.abspath(filename)
178
        name = os.path.basename(filename) if name is None else name
179
        self._emit_artifact_added(name, filename)
180
181
    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function

        """
        # A Run object is single-use: start_time doubles as the "already
        # started" flag.
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []
        else:
            # Highest-priority observers first (negated key => descending).
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        # Without observers there is nobody to receive captured output,
        # so capturing defaults to "no" unless explicitly configured.
        if self.capture_mode is None and not self.observers:
            capture_mode = "no"
        else:
            capture_mode = self.capture_mode
        capture_mode, capture_stdout = get_stdcapturer(capture_mode)
        self.run_logger.debug('Using capture mode "%s"', capture_mode)

        # Queue-only runs fire queued_event and stop before any execution.
        if self.queue_only:
            self._emit_queued()
            return
        try:
            with capture_stdout() as self._output_file:
                self._emit_started()
                self._start_heartbeat()
                self._execute_pre_run_hooks()
                self.result = self.main_function(*args)
                self._execute_post_run_hooks()
                if self.result is not None:
                    self.run_logger.info('Result: {}'.format(self.result))
                elapsed_time = self._stop_time()
                self.run_logger.info('Completed after %s', elapsed_time)
                # Flush remaining output while the capture file is still open.
                self._get_captured_output()
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # SacredInterrupt subclasses can carry a custom STATUS string.
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except Exception:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next skips this frame so the trace starts in user code.
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result
244
245
    def _get_captured_output(self):
246
        if self._output_file.closed:
247
            return
248
        text = self._output_file.get()
249
        if isinstance(text, bytes):
250
            text = text.decode('utf-8', 'replace')
251
        if self.captured_out:
252
            text = self.captured_out + text
253
        if self.captured_out_filter is not None:
254
            text = self.captured_out_filter(text)
255
        self.captured_out = text
256
257
    def _start_heartbeat(self):
258
        if self.beat_interval > 0:
259
            self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create(
260
                self._emit_heartbeat, self.beat_interval)
261
            self._heartbeat.start()
262
263
    def _stop_heartbeat(self):
264
        # only stop if heartbeat was started
265
        if self._heartbeat is not None:
266
            self._stop_heartbeat_event.set()
267
            self._heartbeat.join(2)
268
269
    def _emit_queued(self):
        """Set status to QUEUED and fire queued_event on every observer."""
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'queued_event'):
                continue
            returned_id = observer.queued_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                queue_time=queue_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            # The first observer to return an id determines the run id.
            if self._id is None:
                self._id = returned_id
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))
296
297
    def _emit_started(self):
        """Set status to RUNNING and fire started_event on every observer."""
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'started_event'):
                continue
            returned_id = observer.started_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                start_time=self.start_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            # The first observer to return an id determines the run id.
            if self._id is None:
                self._id = returned_id
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))
322
323
    def _emit_heartbeat(self):
        """Send captured output and new metrics to all observers."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        # Collect every metric measurement recorded since the last beat.
        metrics_by_name = linearize_metrics(self._metrics.get_last_metrics())
        for observer in self.observers:
            self._safe_call(observer, 'log_metrics',
                            metrics_by_name=metrics_by_name,
                            info=self.info)
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time,
                            result=self.result)
338
339
    def _stop_time(self):
340
        self.stop_time = datetime.datetime.utcnow()
341
        elapsed_time = datetime.timedelta(
342
            seconds=round((self.stop_time - self.start_time).total_seconds()))
343
        return elapsed_time
344
345
    def _emit_completed(self, result):
346
        self.status = 'COMPLETED'
347
        for observer in self.observers:
348
            self._final_call(observer, 'completed_event',
349
                             stop_time=self.stop_time,
350
                             result=result)
351
352
    def _emit_interrupted(self, status):
353
        self.status = status
354
        elapsed_time = self._stop_time()
355
        self.run_logger.warning("Aborted after %s!", elapsed_time)
356
        for observer in self.observers:
357
            self._final_call(observer, 'interrupted_event',
358
                             interrupt_time=self.stop_time,
359
                             status=status)
360
361
    def _emit_failed(self, exc_type, exc_value, trace):
362
        self.status = 'FAILED'
363
        elapsed_time = self._stop_time()
364
        self.run_logger.error("Failed after %s!", elapsed_time)
365
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
366
        for observer in self.observers:
367
            self._final_call(observer, 'failed_event',
368
                             fail_time=self.stop_time,
369
                             fail_trace=self.fail_trace)
370
371
    def _emit_resource_added(self, filename):
372
        for observer in self.observers:
373
            self._safe_call(observer, 'resource_event', filename=filename)
374
375
    def _emit_artifact_added(self, name, filename):
376
        for observer in self.observers:
377
            self._safe_call(observer, 'artifact_event',
378
                            name=name,
379
                            filename=filename)
380
381
    def _safe_call(self, obs, method, **kwargs):
382
        if obs not in self._failed_observers and hasattr(obs, method):
383
            try:
384
                getattr(obs, method)(**kwargs)
385
            except Exception as e:
386
                self._failed_observers.append(obs)
387
                self.run_logger.warning("An error ocurred in the '{}' "
388
                                        "observer: {}".format(obs, e))
389
390
    def _final_call(self, observer, method, **kwargs):
391
        if hasattr(observer, method):
392
            try:
393
                getattr(observer, method)(**kwargs)
394
            except Exception:
395
                # Feels dirty to catch all exceptions, but it is just for
396
                # finishing up, so we don't want one observer to kill the
397
                # others
398
                self.run_logger.error(tb.format_exc())
399
400
    def _warn_about_failed_observers(self):
401
        for observer in self._failed_observers:
402
            self.run_logger.warning("The observer '{}' failed at some point "
403
                                    "during the run.".format(observer))
404
405
    def _execute_pre_run_hooks(self):
406
        for pr in self.pre_run_hooks:
407
            pr()
408
409
    def _execute_post_run_hooks(self):
410
        for pr in self.post_run_hooks:
411
            pr()
412
413
    def warn_if_unobserved(self):
414
        if not self.observers and not self.debug and not self.unobserved:
415
            self.run_logger.warning("No observers have been added to this run")
416
417
    def log_scalar(self, metric_name, value, step=None):
418
        """
419
        Add a new measurement.
420
421
        The measurement will be processed by the MongoDB observer
422
        during a heartbeat event.
423
        Other observers are not yet supported.
424
425
        :param metric_name: The name of the metric, e.g. training.loss
426
        :param value: The measured value
427
        :param step: The step number (integer), e.g. the iteration number
428
                    If not specified, an internal counter for each metric
429
                    is used, incremented by one.
430
        """
431
        # Method added in change https://github.com/chovanecm/sacred/issues/4
432
        # The same as Experiment.log_scalar (if something changes,
433
        # update the docstring too!)
434
435
        return self._metrics.log_scalar_metric(metric_name, value, step)
436