Completed
Push — master ( 3e9443...e3cc93 )
by Klaus
9s
created

Run.add_resource()   A

Complexity

Conditions 1

Size

Total Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 17
rs 9.4285
1
#!/usr/bin/env python
2
# coding=utf-8
3
from __future__ import division, print_function, unicode_literals
4
5
import datetime
6
import os.path
7
import sys
8
import threading
9
import traceback as tb
10
11
from tempfile import NamedTemporaryFile
12
13
from sacred.randomness import set_global_seed
14
from sacred.utils import (tee_output, ObserverError, SacredInterrupt,
15
                          join_paths, flush)
16
17
18
# Module-level marker flag: files defining ``__sacred__`` are meant to be
# filtered out of stack traces (the filtering itself is not in this file).
__sacred__ = True
19
20
21
class Run(object):
    """Represent and manage a single run of an experiment.

    Calling a ``Run`` drives its lifecycle: it reports the start to the
    observers, captures the output, sends periodic heartbeats, executes the
    pre-run hooks, the main function and the post-run hooks, and finally
    reports completion, interruption or failure to the observers.
    """

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = None
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run ('QUEUED', 'RUNNING', 'COMPLETED',
        'FAILED', or 'INTERRUPTED' / the STATUS of the interrupt)"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """Meta information (e.g. the queue time) passed to the observers"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        # Pending threading.Timer for the next heartbeat (None when stopped).
        self._heartbeat = None
        # Observers that raised in _safe_call; they are skipped from then on.
        self._failed_observers = []
        # Temporary file receiving the tee'd output while __call__ runs.
        self._output_file = None

    def open_resource(self, filename, mode='r'):
        """Open a file and also save it as a resource.

        Opens a file, reports it to the observers as a resource, and returns
        the opened file.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        with its path and md5 sum.

        See also :py:meth:`sacred.Experiment.open_resource`.

        Parameters
        ----------
        filename : str
            name of the file that should be opened
        mode : str
            mode in which the file is opened (passed on to ``open``)

        Returns
        -------
        file
            the opened file-object
        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
        return open(filename, mode)

    def add_resource(self, filename):
        """Add a file as a resource.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        with its path and md5 sum.

        See also :py:meth:`sacred.Experiment.add_resource`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as a resource
        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)

    def add_artifact(self, filename, name=None):
        """Add a file as an artifact.

        In Sacred terminology an artifact is a file produced by the experiment
        run. In case of a MongoObserver that means storing the file in the
        database.

        See also :py:meth:`sacred.Experiment.add_artifact`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as artifact
        name : str, optional
            optionally set the name of the artifact.
            Defaults to the file-path relative to the current directory.
        """
        filename = os.path.abspath(filename)
        name = os.path.relpath(filename) if name is None else name
        self._emit_artifact_added(name, filename)

    def __call__(self, *args):
        r"""Start this run.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function (``None`` if the run was
            only queued)

        Raises
        ------
        RuntimeError
            if this run has already been started
        """
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []
        else:
            # highest priority first
            self.observers = sorted(self.observers, key=lambda x: -x.priority)

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        if self.queue_only:
            self._emit_queued()
            return

        # Bound up-front so the inner ``finally`` can tell whether entering
        # the temp-file / tee_output context succeeded; otherwise a failure
        # there would be masked by an UnboundLocalError.
        final_out = None
        try:
            try:
                with NamedTemporaryFile() as f, tee_output(f) as final_out:
                    self._output_file = f
                    self._emit_started()
                    self._start_heartbeat()
                    self._execute_pre_run_hooks()
                    self.result = self.main_function(*args)
                    self._execute_post_run_hooks()
                    if self.result is not None:
                        self.run_logger.info('Result: {}'.format(self.result))
                    elapsed_time = self._stop_time()
                    self.run_logger.info('Completed after %s', elapsed_time)
            finally:
                if final_out is not None:
                    self.captured_out = final_out[0]
                    if self.captured_out_filter is not None:
                        self.captured_out = self.captured_out_filter(
                            self.captured_out)
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except BaseException:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # drop this __call__ frame from the reported trace
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result

    def _get_captured_output(self):
        """Refresh ``self.captured_out`` from the temporary output file.

        Does nothing if there is no output file yet (``__call__`` has not set
        one up) or if it is already closed (``__call__`` then takes the final
        output from ``tee_output`` instead).
        """
        if self._output_file is None or self._output_file.closed:
            return  # nothing we can do
        flush()
        self._output_file.flush()
        self._output_file.seek(0)
        # Decode as UTF-8 explicitly (Python 2 would default to ASCII) and
        # replace undecodable bytes: the file is read while output may still
        # be written, so it can end in an incomplete multi-byte sequence.
        text = self._output_file.read().decode('utf-8', 'replace')
        if self.captured_out_filter is not None:
            text = self.captured_out_filter(text)
        self.captured_out = text

    def _start_heartbeat(self):
        """Emit a heartbeat now and schedule the next one.

        Each scheduled timer calls this method again, so heartbeats repeat
        every ``beat_interval`` seconds until ``_stop_heartbeat`` is called.
        A non-positive ``beat_interval`` emits only this single heartbeat.
        """
        self._emit_heartbeat()
        if self.beat_interval > 0:
            timer = threading.Timer(self.beat_interval,
                                    self._start_heartbeat)
            # A daemon timer can never keep the interpreter alive by itself,
            # e.g. when a timer that was already firing while _stop_heartbeat
            # ran re-schedules itself.
            # NOTE(review): that cancel/re-schedule race still exists; a stop
            # flag guarded by a lock would close it completely.
            timer.daemon = True
            self._heartbeat = timer
            timer.start()

    def _stop_heartbeat(self):
        """Cancel the pending heartbeat timer and emit one final heartbeat."""
        if self._heartbeat is not None:
            self._heartbeat.cancel()
        self._heartbeat = None
        self._emit_heartbeat()  # one final beat to flush pending changes

    def _get_command_name(self):
        """Return the main function's prefix joined with its name."""
        return join_paths(self.main_function.prefix,
                          self.main_function.signature.name)

    def _emit_queued(self):
        """Mark the run as queued and report it to the observers.

        Observers implementing ``queued_event`` are called in order; the first
        non-None ID they return becomes ``self._id`` and is passed on to the
        following observers.
        """
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = self._get_command_name()
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'queued_event'):
                _id = observer.queued_event(
                    ex_info=self.experiment_info,
                    command=command,
                    queue_time=queue_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "%s"', self._id)

    def _emit_started(self):
        """Mark the run as running, record the start time, notify observers.

        Observers implementing ``started_event`` are called in order; the
        first non-None ID they return becomes ``self._id`` and is passed on to
        the following observers.
        """
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = self._get_command_name()
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'started_event'):
                _id = observer.started_event(
                    ex_info=self.experiment_info,
                    command=command,
                    host_info=self.host_info,
                    start_time=self.start_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "%s"', self._id)

    def _emit_heartbeat(self):
        """Refresh the captured output and send a heartbeat to observers."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        for observer in self.observers:
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time)

    def _stop_time(self):
        """Record ``stop_time`` and return the elapsed time.

        The elapsed time is rounded to whole seconds. If no ``start_time`` was
        recorded (the run failed before it started), it is reported as zero
        instead of raising a TypeError that would mask the original error.
        """
        self.stop_time = datetime.datetime.utcnow()
        if self.start_time is None:
            return datetime.timedelta(0)
        elapsed_seconds = (self.stop_time - self.start_time).total_seconds()
        return datetime.timedelta(seconds=round(elapsed_seconds))

    def _emit_completed(self, result):
        """Mark the run as completed and notify every observer."""
        self.status = 'COMPLETED'
        for observer in self.observers:
            self._final_call(observer, 'completed_event',
                             stop_time=self.stop_time,
                             result=result)

    def _emit_interrupted(self, status):
        """Record the interrupt ``status`` and notify every observer."""
        self.status = status
        elapsed_time = self._stop_time()
        self.run_logger.warning("Aborted after %s!", elapsed_time)
        for observer in self.observers:
            self._final_call(observer, 'interrupted_event',
                             interrupt_time=self.stop_time,
                             status=status)

    def _emit_failed(self, exc_type, exc_value, trace):
        """Store the formatted traceback as ``fail_trace``, notify observers."""
        self.status = 'FAILED'
        elapsed_time = self._stop_time()
        self.run_logger.error("Failed after %s!", elapsed_time)
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
        for observer in self.observers:
            self._final_call(observer, 'failed_event',
                             fail_time=self.stop_time,
                             fail_trace=self.fail_trace)

    def _emit_resource_added(self, filename):
        """Report ``filename`` as a resource to all observers."""
        for observer in self.observers:
            self._safe_call(observer, 'resource_event', filename=filename)

    def _emit_artifact_added(self, name, filename):
        """Report ``filename`` as an artifact called ``name`` to observers."""
        for observer in self.observers:
            self._safe_call(observer, 'artifact_event',
                            name=name,
                            filename=filename)

    def _safe_call(self, obs, method, **kwargs):
        """Call ``obs.<method>(**kwargs)`` unless the observer already failed.

        Observers without ``method`` are skipped. An ``ObserverError`` marks
        the observer as failed and is logged as a warning; any other exception
        also marks it as failed and is re-raised.
        """
        if obs in self._failed_observers or not hasattr(obs, method):
            return
        try:
            getattr(obs, method)(**kwargs)
        except ObserverError as e:
            self._failed_observers.append(obs)
            self.run_logger.warning("An error occurred in the '%s' "
                                    "observer: %s", obs, e)
        except BaseException:
            self._failed_observers.append(obs)
            raise

    def _final_call(self, observer, method, **kwargs):
        """Call ``observer.<method>(**kwargs)``, logging any Exception."""
        if hasattr(observer, method):
            try:
                getattr(observer, method)(**kwargs)
            except Exception:
                # Feels dirty to catch all exceptions, but it is just for
                # finishing up, so we don't want one observer to kill the
                # others
                self.run_logger.error(tb.format_exc())

    def _warn_about_failed_observers(self):
        """Log a warning for every observer that failed during the run."""
        for observer in self._failed_observers:
            self.run_logger.warning("The observer '%s' failed at some point "
                                    "during the run.", observer)

    def _execute_pre_run_hooks(self):
        """Call every pre-run hook in order."""
        for hook in self.pre_run_hooks:
            hook()

    def _execute_post_run_hooks(self):
        """Call every post-run hook in order."""
        for hook in self.post_run_hooks:
            hook()

    def warn_if_unobserved(self):
        """Warn if there are no observers, unless debug or unobserved."""
        if not self.observers and not self.debug and not self.unobserved:
            self.run_logger.warning("No observers have been added to this run")
406