Completed
Pull Request — master (#1130)
by
unknown
04:55
created

MainLoop.__init__()   B

Complexity: Conditions 4
Size: Total Lines 22
Duplication: Lines 0, Ratio 0 %
Importance: Changes 0

Metric  Value
cc      4
dl      0
loc     22
rs      8.9197
c       0
b       0
f       0
"""The event-based main loop of Blocks."""
import signal
import logging
import traceback

from blocks.config import config
from blocks.log import BACKENDS
from blocks.utils import reraise_as, unpack, change_recursion_limit
from blocks.utils.profile import Profile, Timer
from blocks.algorithms import GradientDescent
from blocks.extensions import CallbackName
from blocks.model import Model

logger = logging.getLogger(__name__)

# Review note (Coding Style, Naming): the name error_message does not
# conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$). This Pylint check looks for invalid
# names for a range of different identifiers. You can set regular
# expressions to which the identifiers must conform if the defaults do not
# match your requirements. If your project includes a Pylint configuration
# file, the settings contained in that file take precedence.
error_message = """

Blocks will attempt to run `on_error` extensions, potentially saving data, \
before exiting and reraising the error. Note that the usual `after_training` \
extensions will *not* be run. The original error will be re-raised and also \
stored in the training log. Press CTRL + C to halt Blocks immediately."""

# Review note (Coding Style, Naming): the name
# error_in_error_handling_message does not conform to the constant naming
# conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
error_in_error_handling_message = """

Blocks will now exit. The remaining `on_error` extensions will not be run."""


# Review note (Coding Style, Naming): the name epoch_interrupt_message does
# not conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
epoch_interrupt_message = """

Blocks will complete this epoch of training and run extensions \
before exiting. If you do not want to complete this epoch, press CTRL + C \
again to stop training after the current batch."""

# Review note (Coding Style, Naming): the name batch_interrupt_message does
# not conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
batch_interrupt_message = """

Blocks will complete the current batch and run extensions before exiting. If \
you do not want to complete this batch, press CTRL + C again. WARNING: Note \
that this will end training immediately, and extensions that e.g. save your \
training progress won't be run."""

# Review note (Coding Style, Naming): the name no_model_message does not
# conform to the constant naming conventions
# ((([A-Z_][A-Z0-9_]*)|(__.*__))$).
no_model_message = """

A possible reason: one of your extensions requires the main loop to have \
a model. Check documentation of your extensions."""


class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance, in order to stop the training
    procedure, an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when
    it finds one.

    The `MainLoop` also handles the interruption signal SIGINT for you
    (e.g. the one the program receives when you press Ctrl + C). It notes
    this event in the log, and at the end of the next iteration or epoch
    the main loop finishes gracefully, calling all necessary extension
    callbacks and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`
        The data stream. Should support the :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by a :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks; it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit), \
                self.log.writer() as log:
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt(log))
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt(log))
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if log.status['iterations_done'] > 0:
                    log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch(log):
                        pass
            except TrainingFinish:
                log.current_row['training_finished'] = True
            # Review note (Best Practice): catching very general exceptions
            # such as Exception is usually not recommended. Generally, you
            # would want to handle very specific errors in the exception
            # handler. This ensures that you do not hide other types of
            # errors which should be fixed. So, unless you specifically
            # plan to handle any error, consider adding a more specific
            # exception.
            except Exception as e:
                self._restore_signal_handlers()
                log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                # Review note (Best Practice): catching very general
                # exceptions such as Exception is usually not recommended.
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if no extension or several extensions are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self, log):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration(log):
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch', log)
        return True

    def _run_iteration(self, log):
        log.new_iteration()
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch', log)
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level, log):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # If a keyboard interrupt is handled right at the end of the
        # iteration, the corresponding log record can be found only in
        # the previous row.
        if (log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, log):
        # Review note (Unused Code): the arguments frame and signal_number
        # seem to be unused (both are required by the signal handler
        # signature).
        def _handler(signal_number, frame):
            # Try to complete the current epoch if user presses CTRL + C
            logger.warning('Received epoch interrupt signal.' +
                           epoch_interrupt_message)
            signal.signal(signal.SIGINT, self._handle_batch_interrupt(log))
            log.current_row['epoch_interrupt_received'] = True
            # Add a record to the status. Unlike the log record it will be
            # easy to access at later iterations.
            self.status['epoch_interrupt_received'] = True
        return _handler

    def _handle_batch_interrupt(self, log):
        # Review note (Unused Code): the arguments frame and signal_number
        # seem to be unused (both are required by the signal handler
        # signature).
        def _handler(signal_number, frame):
            # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
            self._restore_signal_handlers()
            logger.warning('Received batch interrupt signal.' +
                           batch_interrupt_message)
            log.current_row['batch_interrupt_received'] = True
            # Add a record to the status. Unlike the log record it will be
            # easy to access at later iterations.
            self.status['batch_interrupt_received'] = True
        return _handler

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)


class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log."""
    pass
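
Illustrative sketch (not part of the reviewed file): the `MainLoop` docstring says that extensions talk to the main loop through log records, and that a `training_finish_requested=True` record stops training. The toy below mimics that idea without importing Blocks. `ToyLoop`, `StopAfter` and `ToyTrainingFinish` are hypothetical stand-ins, and plain dicts play the role of the log, so this only approximates how the real classes behave.

```python
class ToyTrainingFinish(Exception):
    """Toy counterpart of TrainingFinish (hypothetical name)."""


class StopAfter:
    """Hypothetical extension: asks the loop to stop after `n` batches."""

    def __init__(self, n):
        self.n = n

    def after_batch(self, loop):
        # The extension never stops the loop itself; it only leaves a
        # record in the current log row.
        if loop.status['iterations_done'] >= self.n:
            loop.current_row['training_finish_requested'] = True


class ToyLoop:
    """Toy stand-in for MainLoop; dicts stand in for the training log."""

    def __init__(self, batches, extensions):
        self.batches = batches
        self.extensions = extensions
        self.status = {'iterations_done': 0}
        self.current_row = {}

    def run(self):
        try:
            for _batch in self.batches:
                self.current_row = {}  # a fresh log row per iteration
                self.status['iterations_done'] += 1  # "train" on the batch
                for extension in self.extensions:
                    extension.after_batch(self)
                # The loop checks the log after every batch.
                if self.current_row.get('training_finish_requested', False):
                    raise ToyTrainingFinish
        except ToyTrainingFinish:
            self.current_row['training_finished'] = True
        return self.status['iterations_done']


loop = ToyLoop(batches=range(10), extensions=[StopAfter(3)])
iterations = loop.run()
```

Routing the stop request through the log keeps the extension decoupled from the loop's control flow: the extension writes one record, and the loop decides when to act on it.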
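
Illustrative sketch (not part of the reviewed file): the two interrupt handlers above escalate from "finish the epoch" on the first CTRL + C to "finish the batch" on the second, and then hand SIGINT back to the original handler. The standalone example below reproduces that escalation with the standard `signal` module only. `InterruptTracker` and its method names are hypothetical, and the two key presses are simulated by calling the installed handler directly instead of delivering a real signal.

```python
import signal


class InterruptTracker:
    """Hypothetical helper showing the two-stage SIGINT escalation."""

    def __init__(self):
        self.status = {'epoch_interrupt_received': False,
                       'batch_interrupt_received': False}
        self.original_handler = None

    def install(self):
        # signal.signal returns the previous handler, so it can be
        # restored later. It must be called from the main thread.
        self.original_handler = signal.signal(signal.SIGINT,
                                              self._on_first_interrupt)

    def restore(self):
        signal.signal(signal.SIGINT, self.original_handler)

    def _on_first_interrupt(self, signal_number, frame):
        # First CTRL + C: record it and arm the second-stage handler.
        self.status['epoch_interrupt_received'] = True
        signal.signal(signal.SIGINT, self._on_second_interrupt)

    def _on_second_interrupt(self, signal_number, frame):
        # Second CTRL + C: restore the original handler first, so a
        # third CTRL + C behaves as it did before install().
        self.restore()
        self.status['batch_interrupt_received'] = True


tracker = InterruptTracker()
tracker.install()
# Simulate two CTRL + C presses by calling whichever handler is installed.
signal.getsignal(signal.SIGINT)(signal.SIGINT, None)
after_first = dict(tracker.status)
signal.getsignal(signal.SIGINT)(signal.SIGINT, None)
after_second = dict(tracker.status)
restored = signal.getsignal(signal.SIGINT) == tracker.original_handler
```

Swapping the handler inside the handler means no counter of received signals has to be kept; the currently installed handler already encodes which stage the loop is in.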