Completed
Pull Request — master (#184)
by Martin
44s
created

Run.log_scalar()   A

Complexity

Conditions 1

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 1
c 2
b 0
f 0
dl 0
loc 19
rs 9.4285
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import datetime
6
import os.path
7
import sys
8
import threading
9
import traceback as tb
10
11
from sacred import metrics_logger
12
from sacred.metrics_logger import linearize_metrics
13
from sacred.randomness import set_global_seed
14
from sacred.utils import ObserverError, SacredInterrupt, join_paths
15
from sacred.stdout_capturing import get_stdcapturer, flush
16
17
18
__sacred__ = True  # marks files that should be filtered from stack traces
19
20
21
class Run(object):
    """Represent and manage a single run of an experiment."""

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):
        # All mutable run state lives on the instance; a Run is single-use
        # (see __call__) and observers read most of these attributes.

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = None
        """Captured stdout and stderr"""

        self.captured_out_cursor = 0
        """Cursor on captured_out to read by chunks"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """A dictionary with meta information (e.g. a comment) for this run"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        self.capture_mode = None
        """Determines the way the stdout/stderr are captured"""

        self._heartbeat = None  # threading.Timer for the next heartbeat
        self._failed_observers = []  # observers skipped after an error
        self._output_file = None  # file object capturing stdout/stderr

        self._metrics = metrics_logger.MetricsLogger()
    def open_resource(self, filename, mode='r'):
119
        """Open a file and also save it as a resource.
120
121
        Opens a file, reports it to the observers as a resource, and returns
122
        the opened file.
123
124
        In Sacred terminology a resource is a file that the experiment needed
125
        to access during a run. In case of a MongoObserver that means making
126
        sure the file is stored in the database (but avoiding duplicates) along
127
        its path and md5 sum.
128
129
        See also :py:meth:`sacred.Experiment.open_resource`.
130
131
        Parameters
132
        ----------
133
        filename : str
134
            name of the file that should be opened
135
        mode : str
136
            mode that file will be open
137
138
        Returns
139
        -------
140
        file
141
            the opened file-object
142
        """
143
        filename = os.path.abspath(filename)
144
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
145
        return open(filename, mode)
146
147
    def add_resource(self, filename):
148
        """Add a file as a resource.
149
150
        In Sacred terminology a resource is a file that the experiment needed
151
        to access during a run. In case of a MongoObserver that means making
152
        sure the file is stored in the database (but avoiding duplicates) along
153
        its path and md5 sum.
154
155
        See also :py:meth:`sacred.Experiment.add_resource`.
156
157
        Parameters
158
        ----------
159
        filename : str
160
            name of the file to be stored as a resource
161
        """
162
        filename = os.path.abspath(filename)
163
        self._emit_resource_added(filename)
164
165
    def add_artifact(self, filename, name=None):
166
        """Add a file as an artifact.
167
168
        In Sacred terminology an artifact is a file produced by the experiment
169
        run. In case of a MongoObserver that means storing the file in the
170
        database.
171
172
        See also :py:meth:`sacred.Experiment.add_artifact`.
173
174
        Parameters
175
        ----------
176
        filename : str
177
            name of the file to be stored as artifact
178
        name : str, optional
179
            optionally set the name of the artifact.
180
            Defaults to the filename.
181
        """
182
        filename = os.path.abspath(filename)
183
        name = os.path.basename(filename) if name is None else name
184
        self._emit_artifact_added(name, filename)
185
186
    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function

        Raises
        ------
        RuntimeError
            if this run has already been started once
        """
        # A Run object is single-use: start_time doubles as the
        # "already started" flag.
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []
        else:
            # Highest-priority observers are notified first.
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        # Without observers there is no one to receive captured output, so
        # capturing defaults to "no" unless a mode was set explicitly.
        if self.capture_mode is None and not self.observers:
            capture_mode = "no"
        else:
            capture_mode = self.capture_mode
        capture_mode, capture_stdout = get_stdcapturer(capture_mode)
        self.run_logger.debug('Using capture mode "%s"', capture_mode)

        if self.queue_only:
            # Only announce the queued run; the main function is not executed.
            self._emit_queued()
            return
        try:
            try:
                with capture_stdout() as (f, final_out):
                    self._output_file = f
                    self._emit_started()
                    self._start_heartbeat()
                    self._execute_pre_run_hooks()
                    self.result = self.main_function(*args)
                    self._execute_post_run_hooks()
                    if self.result is not None:
                        self.run_logger.info('Result: {}'.format(self.result))
                    elapsed_time = self._stop_time()
                    self.run_logger.info('Completed after %s', elapsed_time)
                    self._get_captured_output()
            finally:
                # Collect whatever output was produced, even on failure.
                self._get_captured_output()
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # A SacredInterrupt subclass may carry a custom STATUS.
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next hides this (sacred-internal) frame from the reported
            # stack trace.
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result
    def _get_captured_output(self):
        """Append freshly captured output to ``captured_out``.

        Reads everything the capture file gained since the last call
        (tracked via ``captured_out_cursor``), prepends the previously
        captured text and applies the optional output filter.
        """
        out_file = self._output_file
        if out_file.closed:
            # The capture file is gone; nothing we can do.
            return
        flush()
        out_file.flush()
        out_file.seek(self.captured_out_cursor)
        new_text = out_file.read()
        if isinstance(new_text, bytes):
            new_text = new_text.decode()
        self.captured_out_cursor += len(new_text)
        combined = (self.captured_out + new_text) if self.captured_out \
            else new_text
        if self.captured_out_filter is not None:
            combined = self.captured_out_filter(combined)
        self.captured_out = combined
269 View Code Duplication
    def _start_heartbeat(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
270
        self._emit_heartbeat()
271
        if self.beat_interval > 0:
272
            self._heartbeat = threading.Timer(self.beat_interval,
273
                                              self._start_heartbeat)
274
            self._heartbeat.start()
275
276
    def _stop_heartbeat(self):
        """Cancel any pending heartbeat timer and emit one final beat."""
        timer = self._heartbeat
        self._heartbeat = None
        if timer is not None:
            timer.cancel()
        self._emit_heartbeat()  # one final beat to flush pending changes
    def _emit_queued(self):
        """Notify all observers that this run has been queued.

        The first observer to return an id determines ``self._id``.
        """
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'queued_event'):
                continue
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails
            returned_id = observer.queued_event(
                ex_info=self.experiment_info,
                command=command,
                queue_time=queue_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            if self._id is None:
                self._id = returned_id

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))
    def _emit_started(self):
        """Notify all observers that this run has started.

        The first observer to return an id determines ``self._id``.
        """
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if not hasattr(observer, 'started_event'):
                continue
            # do not catch any exceptions on startup:
            # the experiment SHOULD fail if any of the observers fails
            returned_id = observer.started_event(
                ex_info=self.experiment_info,
                command=command,
                host_info=self.host_info,
                start_time=self.start_time,
                config=self.config,
                meta_info=self.meta_info,
                _id=self._id
            )
            if self._id is None:
                self._id = returned_id
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))
    def _emit_heartbeat(self):
        """Send buffered metrics, captured output and info to observers."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        # Read all measured metrics since the last heartbeat
        metrics_by_name = linearize_metrics(self._metrics.get_last_metrics())
        for observer in self.observers:
            self._safe_call(observer, 'log_metrics',
                            metrics_by_name=metrics_by_name,
                            info=self.info)
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time,
                            result=self.result)
    def _stop_time(self):
352
        self.stop_time = datetime.datetime.utcnow()
353
        elapsed_time = datetime.timedelta(
354
            seconds=round((self.stop_time - self.start_time).total_seconds()))
355
        return elapsed_time
356
357
    def _emit_completed(self, result):
358
        self.status = 'COMPLETED'
359
        for observer in self.observers:
360
            self._final_call(observer, 'completed_event',
361
                             stop_time=self.stop_time,
362
                             result=result)
363
364
    def _emit_interrupted(self, status):
365
        self.status = status
366
        elapsed_time = self._stop_time()
367
        self.run_logger.warning("Aborted after %s!", elapsed_time)
368
        for observer in self.observers:
369
            self._final_call(observer, 'interrupted_event',
370
                             interrupt_time=self.stop_time,
371
                             status=status)
372
373
    def _emit_failed(self, exc_type, exc_value, trace):
374
        self.status = 'FAILED'
375
        elapsed_time = self._stop_time()
376
        self.run_logger.error("Failed after %s!", elapsed_time)
377
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
378
        for observer in self.observers:
379
            self._final_call(observer, 'failed_event',
380
                             fail_time=self.stop_time,
381
                             fail_trace=self.fail_trace)
382
383
    def _emit_resource_added(self, filename):
384
        for observer in self.observers:
385
            self._safe_call(observer, 'resource_event', filename=filename)
386
387
    def _emit_artifact_added(self, name, filename):
388
        for observer in self.observers:
389
            self._safe_call(observer, 'artifact_event',
390
                            name=name,
391
                            filename=filename)
392
393
    def _safe_call(self, obs, method, **kwargs):
394
        if obs not in self._failed_observers and hasattr(obs, method):
395
            try:
396
                getattr(obs, method)(**kwargs)
397
            except ObserverError as e:
398
                self._failed_observers.append(obs)
399
                self.run_logger.warning("An error ocurred in the '{}' "
400
                                        "observer: {}".format(obs, e))
401
            except:
402
                self._failed_observers.append(obs)
403
                raise
404
405
    def _final_call(self, observer, method, **kwargs):
406
        if hasattr(observer, method):
407
            try:
408
                getattr(observer, method)(**kwargs)
409
            except Exception:
410
                # Feels dirty to catch all exceptions, but it is just for
411
                # finishing up, so we don't want one observer to kill the
412
                # others
413
                self.run_logger.error(tb.format_exc())
414
415
    def _warn_about_failed_observers(self):
416
        for observer in self._failed_observers:
417
            self.run_logger.warning("The observer '{}' failed at some point "
418
                                    "during the run.".format(observer))
419
420
    def _execute_pre_run_hooks(self):
421
        for pr in self.pre_run_hooks:
422
            pr()
423
424
    def _execute_post_run_hooks(self):
425
        for pr in self.post_run_hooks:
426
            pr()
427
428
    def warn_if_unobserved(self):
429
        if not self.observers and not self.debug and not self.unobserved:
430
            self.run_logger.warning("No observers have been added to this run")
431
432
    def log_scalar(self, metric_name, value, step=None):
433
        """
434
        Add a new measurement.
435
436
        The measurement will be processed by the MongoDB observer
437
        during a heartbeat event.
438
        Other observers are not yet supported.
439
440
        :param metric_name: The name of the metric, e.g. training.loss
441
        :param value: The measured value
442
        :param step: The step number (integer), e.g. the iteration number
443
                    If not specified, an internal counter for each metric
444
                    is used, incremented by one.
445
        """
446
        # Method added in change https://github.com/chovanecm/sacred/issues/4
447
        # The same as Experiment.log_scalar (if something changes,
448
        # update the docstring too!)
449
450
        return self._metrics.log_scalar_metric(metric_name, value, step)
451