MainLoop (rating: B)

Complexity

Total Complexity 51

Size/Duplication

Total Lines 262
Duplicated Lines 0 %

Importance

Changes 0
Metric   Value
dl       0        (duplicated lines)
loc      262      (lines of code)
rs       8.3206
c        0
b        0
f        0
wmc      51       (weighted method count; equals the total complexity above)

13 Methods

Rating   Name   Duplication   Size   Complexity  
A find_extension() 0 15 3
B _run_epoch() 0 20 5
A _handle_batch_interrupt() 0 9 1
A iteration_state() 0 4 1
B _run_iteration() 0 16 5
A _restore_signal_handlers() 0 3 1
F run() 0 66 17
B __init__() 0 23 4
A _handle_epoch_interrupt() 0 9 1
B _check_finish_training() 0 19 5
A _run_extensions() 0 5 4
A status() 0 4 1
A model() 0 6 2

How to fix: Complexity

Complex Class

Complex classes like MainLoop often do a lot of different things. To break such a class down, we need to identify a cohesive component within the class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
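As an illustration: the signal-handling methods of MainLoop (`_handle_epoch_interrupt`, `_handle_batch_interrupt`, `_restore_signal_handlers`) and the `original_*_handler` fields share a common concern and could move into a dedicated collaborator. The sketch below is hypothetical (the name `InterruptHandler` and the standalone design are invented for this example, not part of Blocks):

```python
# Hypothetical Extract Class sketch: the SIGINT/SIGTERM escalation logic
# from MainLoop pulled into its own small class.
import signal


class InterruptHandler:
    """Owns the signal handling that previously lived inside the big class."""

    def __init__(self, status):
        # `status` is a dict-like object shared with the main loop.
        self.status = status
        self.original_sigint_handler = None
        self.original_sigterm_handler = None

    def install(self):
        # Remember the previous handlers so they can be restored later.
        self.original_sigint_handler = signal.signal(
            signal.SIGINT, self.handle_epoch_interrupt)
        self.original_sigterm_handler = signal.signal(
            signal.SIGTERM, self.handle_batch_interrupt)

    def handle_epoch_interrupt(self, signal_number, frame):
        # First CTRL + C: finish the epoch; a second CTRL + C escalates
        # to the batch-level handler.
        signal.signal(signal.SIGINT, self.handle_batch_interrupt)
        self.status['epoch_interrupt_received'] = True

    def handle_batch_interrupt(self, signal_number, frame):
        # Second CTRL + C (or SIGTERM): finish the current batch only.
        self.restore()
        self.status['batch_interrupt_received'] = True

    def restore(self):
        # Put the original handlers back, if any were saved.
        if self.original_sigint_handler is not None:
            signal.signal(signal.SIGINT, self.original_sigint_handler)
        if self.original_sigterm_handler is not None:
            signal.signal(signal.SIGTERM, self.original_sigterm_handler)
```

The main loop would then hold a single `InterruptHandler` instance instead of three methods and two fields, which removes five members from the complex class without changing behaviour.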

"""The event-based main loop of Blocks."""
import signal
import logging
import traceback

from blocks.config import config
from blocks.log import BACKENDS
from blocks.utils import reraise_as, unpack, change_recursion_limit
from blocks.utils.profile import Profile, Timer
from blocks.algorithms import GradientDescent
from blocks.extensions import CallbackName
from blocks.model import Model

logger = logging.getLogger(__name__)

# Note (from the analysis above): Pylint flags the lower-case names of the
# message constants below (`error_message`, `error_in_error_handling_message`,
# `epoch_interrupt_message`, `batch_interrupt_message`, `no_model_message`)
# as not conforming to the UPPER_CASE constant naming convention.
error_message = """

Blocks will attempt to run `on_error` extensions, potentially saving data, \
before exiting and reraising the error. Note that the usual `after_training` \
extensions will *not* be run. The original error will be re-raised and also \
stored in the training log. Press CTRL + C to halt Blocks immediately."""

error_in_error_handling_message = """

Blocks will now exit. The remaining `on_error` extensions will not be run."""


epoch_interrupt_message = """

Blocks will complete this epoch of training and run extensions \
before exiting. If you do not want to complete this epoch, press CTRL + C \
again to stop training after the current batch."""

batch_interrupt_message = """

Blocks will complete the current batch and run extensions before exiting. If \
you do not want to complete this batch, press CTRL + C again. WARNING: Note \
that this will end training immediately, and extensions that e.g. save your \
training progress won't be run."""

no_model_message = """

A possible reason: one of your extensions requires the main loop to have \
a model. Check documentation of your extensions."""


class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: it only fetches data from the
    data stream and feeds it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance, in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when it
    finds one.

    The `MainLoop` also handles the interruption signal SIGINT for you
    (e.g. the one your program receives when you press Ctrl + C). It notes
    this event in the log, and at the next iteration or epoch end the main
    loop will be gracefully finished, calling all necessary extension
    callbacks and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`
        The data stream. Should support the :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by a :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks; it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                # Pylint flags this broad except; it is kept so that the
                # `on_error` extensions run no matter what went wrong.
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if no extension with the given name is found, or if
        several are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case the keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if the user presses CTRL + C.
        # (`signal_number` and `frame` are unused but required by the
        # signal handler signature; Pylint flags them.)
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After a 2nd CTRL + C or a SIGTERM signal (e.g. from a cluster
        # scheduler), finish the current batch.
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)


class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log."""
    pass
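
The log-record control protocol described in the class docstring (an extension writes `training_finish_requested=True` into the log; the loop checks for that record after every batch) can be sketched in isolation. This is a toy model, not Blocks itself: the names `StopAfter` and `toy_main_loop`, and the plain-dict log, are invented for illustration.

```python
# Toy sketch of MainLoop's extension/log protocol: extensions observe and
# mutate a shared log; the loop terminates when an extension requests it.

class StopAfter:
    """Toy extension: requests a finish after `n` batches."""

    def __init__(self, n):
        self.n = n

    def after_batch(self, log):
        # Communicate with the loop only through a record in the log,
        # mirroring how Blocks extensions signal termination.
        if log['iterations_done'] >= self.n:
            log['training_finish_requested'] = True


def toy_main_loop(batches, extensions):
    log = {'iterations_done': 0, 'training_finish_requested': False}
    for batch in batches:
        # ... a real loop would call algorithm.process_batch(batch) here ...
        log['iterations_done'] += 1
        # Run the `after_batch` callback of every extension, in order.
        for extension in extensions:
            extension.after_batch(log)
        # The check the real MainLoop does in _check_finish_training('batch').
        if log['training_finish_requested']:
            break
    return log
```

For example, `toy_main_loop(range(100), [StopAfter(3)])` stops after three iterations even though the data stream has more batches, which is exactly the contract an extension such as `FinishAfter` relies on in Blocks.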