Completed
Push — master ( dbc38f...56accc )
by Klaus
01:34
created

Run.__init__()   B

Complexity

Conditions 1

Size

Total Lines 85

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 1 Features 0
Metric Value
cc 1
dl 0
loc 85
rs 8.6875
c 2
b 1
f 0

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include the Extract Method refactoring: move a cohesive part of the long method's body into a new, well-named method and call it from the original.

#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals

import datetime
import os.path
import sys
import threading
import traceback as tb

from tempfile import NamedTemporaryFile

from sacred.randomness import set_global_seed
from sacred.utils import (tee_output, ObserverError, SacredInterrupt,
                          join_paths, flush)


__sacred__ = True  # marks files that should be filtered from stack traces
21
class Run(object):
    """Represent and manage a single run of an experiment.

    A ``Run`` ties together the final configuration, the main function,
    the observers, loggers and hooks, and drives the whole lifecycle:
    QUEUED -> RUNNING -> COMPLETED / INTERRUPTED / FAILED.  It is started
    by calling the instance (see :py:meth:`__call__`).
    """

    def __init__(self, config, config_modifications, main_function, observers,
                 root_logger, run_logger, experiment_info, host_info,
                 pre_run_hooks, post_run_hooks, captured_out_filter=None):

        self._id = None
        """The ID of this run as assigned by the first observer"""

        self.captured_out = None
        """Captured stdout and stderr"""

        self.config = config
        """The final configuration used for this run"""

        self.config_modifications = config_modifications
        """A ConfigSummary object with information about config changes"""

        self.experiment_info = experiment_info
        """A dictionary with information about the experiment"""

        self.host_info = host_info
        """A dictionary with information about the host"""

        self.info = {}
        """Custom info dict that will be sent to the observers"""

        self.root_logger = root_logger
        """The root logger that was used to create all the others"""

        self.run_logger = run_logger
        """The logger that is used for this run"""

        self.main_function = main_function
        """The main function that is executed with this run"""

        self.observers = observers
        """A list of all observers that observe this run"""

        self.pre_run_hooks = pre_run_hooks
        """List of pre-run hooks (captured functions called before this run)"""

        self.post_run_hooks = post_run_hooks
        """List of post-run hooks (captured functions called after this run)"""

        self.result = None
        """The return value of the main function"""

        self.status = None
        """The current status of the run, from QUEUED to COMPLETED"""

        self.start_time = None
        """The datetime when this run was started"""

        self.stop_time = None
        """The datetime when this run stopped"""

        self.debug = False
        """Determines whether this run is executed in debug mode"""

        self.pdb = False
        """If true the pdb debugger is automatically started after a failure"""

        self.meta_info = {}
        """A dict with meta information (e.g. queue_time) sent to observers"""

        self.beat_interval = 10.0  # sec
        """The time between two heartbeat events measured in seconds"""

        self.unobserved = False
        """Indicates whether this run should be unobserved"""

        self.force = False
        """Disable warnings about suspicious changes"""

        self.queue_only = False
        """If true then this run will only fire the queued_event and quit"""

        self.captured_out_filter = captured_out_filter
        """Filter function to be applied to captured output"""

        self.fail_trace = None
        """A stacktrace, in case the run failed"""

        # internal state: heartbeat timer, observers that raised, and the
        # temporary file that receives the teed stdout/stderr
        self._heartbeat = None
        self._failed_observers = []
        self._output_file = None

    def open_resource(self, filename):
        """Open a file and also save it as a resource.

        Opens a file, reports it to the observers as a resource, and returns
        the opened file.

        In Sacred terminology a resource is a file that the experiment needed
        to access during a run. In case of a MongoObserver that means making
        sure the file is stored in the database (but avoiding duplicates) along
        its path and md5 sum.

        See also :py:meth:`sacred.Experiment.open_resource`.

        Parameters
        ----------
        filename : str
            name of the file that should be opened

        Returns
        -------
        file
            the opened file-object
        """
        filename = os.path.abspath(filename)
        self._emit_resource_added(filename)  # TODO: maybe non-blocking?
        return open(filename, 'r')  # TODO: How to deal with binary mode?

    def add_artifact(self, filename, name=None):
        """Add a file as an artifact.

        In Sacred terminology an artifact is a file produced by the experiment
        run. In case of a MongoObserver that means storing the file in the
        database.

        See also :py:meth:`sacred.Experiment.add_artifact`.

        Parameters
        ----------
        filename : str
            name of the file to be stored as artifact
        name : str, optional
            optionally set the name of the artifact.
            Defaults to the relative file-path.
        """
        filename = os.path.abspath(filename)
        name = os.path.relpath(filename) if name is None else name
        self._emit_artifact_added(name, filename)

    def __call__(self, *args):
        r"""Start this run.

        Runs the pre-run hooks, the main function and the post-run hooks
        while capturing stdout/stderr and sending lifecycle and heartbeat
        events to all observers.  A run can only be started once.

        Parameters
        ----------
        \*args
            parameters passed to the main function

        Returns
        -------
            the return value of the main function

        Raises
        ------
        RuntimeError
            if this run has already been started
        """
        if self.start_time is not None:
            raise RuntimeError('A run can only be started once. '
                               '(Last start was {})'.format(self.start_time))

        if self.unobserved:
            self.observers = []

        self.warn_if_unobserved()
        set_global_seed(self.config['seed'])

        if self.queue_only:
            self._emit_queued()
            return
        try:
            try:
                # NOTE(review): if NamedTemporaryFile/tee_output itself raises,
                # `final_out` is unbound in the finally below and a NameError
                # masks the original error — confirm whether that can happen.
                with NamedTemporaryFile() as f, tee_output(f) as final_out:
                    self._output_file = f
                    self._emit_started()
                    self._start_heartbeat()
                    self._execute_pre_run_hooks()
                    self.result = self.main_function(*args)
                    self._execute_post_run_hooks()
                    if self.result is not None:
                        self.run_logger.info('Result: {}'.format(self.result))
                    elapsed_time = self._stop_time()
                    self.run_logger.info('Completed after %s', elapsed_time)
            finally:
                # always record (and optionally filter) the captured output,
                # even if the main function raised
                self.captured_out = final_out[0]
                if self.captured_out_filter is not None:
                    self.captured_out = self.captured_out_filter(
                        self.captured_out)
            self._stop_heartbeat()
            self._emit_completed(self.result)
        except (SacredInterrupt, KeyboardInterrupt) as e:
            self._stop_heartbeat()
            # custom SacredInterrupt subclasses may carry their own status
            status = getattr(e, 'STATUS', 'INTERRUPTED')
            self._emit_interrupted(status)
            raise
        except BaseException:
            exc_type, exc_value, trace = sys.exc_info()
            self._stop_heartbeat()
            # tb_next skips this frame so the trace starts at user code
            self._emit_failed(exc_type, exc_value, trace.tb_next)
            raise
        finally:
            self._warn_about_failed_observers()

        return self.result

    def _get_captured_output(self):
        """Refresh ``self.captured_out`` from the teed output file."""
        if self._output_file.closed:
            return  # nothing we can do
        flush()
        self._output_file.flush()
        self._output_file.seek(0)
        # NamedTemporaryFile defaults to binary mode, hence the decode()
        text = self._output_file.read().decode()
        if self.captured_out_filter is not None:
            text = self.captured_out_filter(text)
        self.captured_out = text

    def _start_heartbeat(self):
        """Emit a heartbeat and re-arm a timer for the next one."""
        self._emit_heartbeat()
        if self.beat_interval > 0:
            # one-shot Timer calls this method again, forming a beat loop
            self._heartbeat = threading.Timer(self.beat_interval,
                                              self._start_heartbeat)
            self._heartbeat.start()

    def _stop_heartbeat(self):
        """Cancel the heartbeat timer and emit one final beat."""
        if self._heartbeat is not None:
            self._heartbeat.cancel()
        self._heartbeat = None
        self._emit_heartbeat()  # one final beat to flush pending changes

    def _emit_queued(self):
        """Fire the queued_event on all observers and record the run ID."""
        self.status = 'QUEUED'
        queue_time = datetime.datetime.utcnow()
        self.meta_info['queue_time'] = queue_time
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Queuing-up command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'queued_event'):
                _id = observer.queued_event(
                    ex_info=self.experiment_info,
                    command=command,
                    queue_time=queue_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                # the first observer to return an ID names this run
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails

        if self._id is None:
            self.run_logger.info('Queued')
        else:
            self.run_logger.info('Queued-up run with ID "{}"'.format(self._id))

    def _emit_started(self):
        """Fire the started_event on all observers and record the run ID."""
        self.status = 'RUNNING'
        self.start_time = datetime.datetime.utcnow()
        command = join_paths(self.main_function.prefix,
                             self.main_function.signature.name)
        self.run_logger.info("Running command '%s'", command)
        for observer in self.observers:
            if hasattr(observer, 'started_event'):
                _id = observer.started_event(
                    ex_info=self.experiment_info,
                    command=command,
                    host_info=self.host_info,
                    start_time=self.start_time,
                    config=self.config,
                    meta_info=self.meta_info,
                    _id=self._id
                )
                # the first observer to return an ID names this run
                if self._id is None:
                    self._id = _id
                # do not catch any exceptions on startup:
                # the experiment SHOULD fail if any of the observers fails
        if self._id is None:
            self.run_logger.info('Started')
        else:
            self.run_logger.info('Started run with ID "{}"'.format(self._id))

    def _emit_heartbeat(self):
        """Send current info and captured output to all observers."""
        beat_time = datetime.datetime.utcnow()
        self._get_captured_output()
        for observer in self.observers:
            self._safe_call(observer, 'heartbeat_event',
                            info=self.info,
                            captured_out=self.captured_out,
                            beat_time=beat_time)

    def _stop_time(self):
        """Record the stop time and return the elapsed time as a timedelta."""
        self.stop_time = datetime.datetime.utcnow()
        # round to whole seconds for readable log output
        elapsed_time = datetime.timedelta(
            seconds=round((self.stop_time - self.start_time).total_seconds()))
        return elapsed_time

    def _emit_completed(self, result):
        """Fire the completed_event on all observers."""
        self.status = 'COMPLETED'
        for observer in self.observers:
            self._final_call(observer, 'completed_event',
                             stop_time=self.stop_time,
                             result=result)

    def _emit_interrupted(self, status):
        """Fire the interrupted_event on all observers."""
        self.status = status
        elapsed_time = self._stop_time()
        self.run_logger.warning("Aborted after %s!", elapsed_time)
        for observer in self.observers:
            self._final_call(observer, 'interrupted_event',
                             interrupt_time=self.stop_time,
                             status=status)

    def _emit_failed(self, exc_type, exc_value, trace):
        """Record the failure trace and fire failed_event on all observers."""
        self.status = 'FAILED'
        elapsed_time = self._stop_time()
        self.run_logger.error("Failed after %s!", elapsed_time)
        self.fail_trace = tb.format_exception(exc_type, exc_value, trace)
        for observer in self.observers:
            self._final_call(observer, 'failed_event',
                             fail_time=self.stop_time,
                             fail_trace=self.fail_trace)

    def _emit_resource_added(self, filename):
        """Report a resource file to all observers."""
        for observer in self.observers:
            self._safe_call(observer, 'resource_event', filename=filename)

    def _emit_artifact_added(self, name, filename):
        """Report an artifact file to all observers."""
        for observer in self.observers:
            self._safe_call(observer, 'artifact_event',
                            name=name,
                            filename=filename)

    def _safe_call(self, obs, method, **kwargs):
        """Call ``method`` on ``obs`` if present, disabling it on failure.

        An ObserverError marks the observer as failed and is logged;
        any other exception also marks it as failed but propagates.
        """
        if obs not in self._failed_observers and hasattr(obs, method):
            try:
                getattr(obs, method)(**kwargs)
            except ObserverError as e:
                self._failed_observers.append(obs)
                self.run_logger.warning("An error occurred in the '{}' "
                                        "observer: {}".format(obs, e))
            except BaseException:
                self._failed_observers.append(obs)
                raise

    def _final_call(self, observer, method, **kwargs):
        """Call ``method`` on ``observer``, logging (not raising) any error."""
        if hasattr(observer, method):
            try:
                getattr(observer, method)(**kwargs)
            except Exception:
                # Feels dirty to catch all exceptions, but it is just for
                # finishing up, so we don't want one observer to kill the
                # others
                self.run_logger.error(tb.format_exc())

    def _warn_about_failed_observers(self):
        """Log a warning for every observer that failed during the run."""
        for observer in self._failed_observers:
            self.run_logger.warning("The observer '{}' failed at some point "
                                    "during the run.".format(observer))

    def _execute_pre_run_hooks(self):
        """Run all registered pre-run hooks in order."""
        for pr in self.pre_run_hooks:
            pr()

    def _execute_post_run_hooks(self):
        """Run all registered post-run hooks in order."""
        for pr in self.post_run_hooks:
            pr()

    def warn_if_unobserved(self):
        """Warn when the run has no observers (unless debug/unobserved)."""
        if not self.observers and not self.debug and not self.unobserved:
            self.run_logger.warning("No observers have been added to this run")