Completed
Push — master ( 55d315...17460c )
by Dmitry
55:12
created

blocks/main_loop.py (2 issues)

1
"""The event-based main loop of Blocks."""
2
import signal
3
import logging
4
import traceback
5
6
from blocks.config import config
7
from blocks.log import BACKENDS
8
from blocks.utils import reraise_as, unpack, change_recursion_limit
9
from blocks.utils.profile import Profile, Timer
10
from blocks.algorithms import GradientDescent
11
from blocks.extensions import CallbackName
12
from blocks.model import Model
13
14
logger = logging.getLogger(__name__)
15
16
error_message = """
17
18
Blocks will attempt to run `on_error` extensions, potentially saving data, \
19
before exiting and reraising the error. Note that the usual `after_training` \
20
extensions will *not* be run. The original error will be re-raised and also \
21
stored in the training log. Press CTRL + C to halt Blocks immediately."""
22
23
error_in_error_handling_message = """
24
25
Blocks will now exit. The remaining `on_error` extensions will not be run."""
26
27
28
epoch_interrupt_message = """
29
30
Blocks will complete this epoch of training and run extensions \
31
before exiting. If you do not want to complete this epoch, press CTRL + C \
32
again to stop training after the current batch."""
33
34
batch_interrupt_message = """
35
36
Blocks will complete the current batch and run extensions before exiting. If \
37
you do not want to complete this batch, press CTRL + C again. WARNING: Note \
38
that this will end training immediately, and extensions that e.g. save your \
39
training progress won't be run."""
40
41
no_model_message = """
0 ignored issues
show
Coding Style Naming introduced by
The name no_model_message does not conform to the constant naming conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$).

This check looks for invalid names for a range of different identifiers.

You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.

If your project includes a Pylint configuration file, the settings contained in that file take precedence.

To find out more about Pylint, please refer to their site.

Loading history...
42
43
A possible reason: one of your extensions requires the main loop to have \
44
a model. Check documentation of your extensions."""
45
46
47
class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when
    it finds it.

    The `MainLoop` also handles interruption signal SIGINT for you (e.g.
    the one program receives when you press Ctrl + C). It notes this event
    in the log and at the next iteration or epoch end the main loop will
    be gracefully finished, with calling all necessary extension callbacks
    and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks, it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        """The model, if one was given; raises otherwise.

        Raises
        ------
        AttributeError
            If no model was passed to the constructor.

        """
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                # Catching `Exception` broadly is deliberate: any failure
                # during training must be recorded in the log and handed to
                # the `on_error` extensions before being re-raised below.
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if there is no extension with the given name, or if
        several extensions with this name are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        # Runs one epoch: obtains a fresh epoch iterator (unless resuming
        # a started epoch), consumes all its batches and fires the epoch
        # callbacks.  Returns False when the data stream is exhausted.
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        # Processes a single batch; returns False when the epoch iterator
        # is exhausted.
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        # Dispatches the `method_name` callback on every extension, in
        # order, timing each extension individually for the profile.
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        # Put back whatever handlers were installed before `run` started.
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
309
310
class TrainingFinish(Exception):
    """Signals that a finish request was encountered in the log."""
313