1
|
|
|
"""The event-based main loop of Blocks.""" |
2
|
|
|
import signal |
3
|
|
|
import logging |
4
|
|
|
import traceback |
5
|
|
|
|
6
|
|
|
from blocks.config import config |
7
|
|
|
from blocks.log import BACKENDS |
8
|
|
|
from blocks.utils import reraise_as, unpack, change_recursion_limit |
9
|
|
|
from blocks.utils.profile import Profile, Timer |
10
|
|
|
from blocks.algorithms import GradientDescent |
11
|
|
|
from blocks.extensions import CallbackName |
12
|
|
|
from blocks.model import Model |
13
|
|
|
|
14
|
|
|
logger = logging.getLogger(__name__) |
15
|
|
|
|
16
|
|
|
error_message = """ |
|
|
|
|
17
|
|
|
|
18
|
|
|
Blocks will attempt to run `on_error` extensions, potentially saving data, \ |
19
|
|
|
before exiting and reraising the error. Note that the usual `after_training` \ |
20
|
|
|
extensions will *not* be run. The original error will be re-raised and also \ |
21
|
|
|
stored in the training log. Press CTRL + C to halt Blocks immediately.""" |
22
|
|
|
|
23
|
|
|
error_in_error_handling_message = """ |
|
|
|
|
24
|
|
|
|
25
|
|
|
Blocks will now exit. The remaining `on_error` extensions will not be run.""" |
26
|
|
|
|
27
|
|
|
|
28
|
|
|
epoch_interrupt_message = """ |
|
|
|
|
29
|
|
|
|
30
|
|
|
Blocks will complete this epoch of training and run extensions \ |
31
|
|
|
before exiting. If you do not want to complete this epoch, press CTRL + C \ |
32
|
|
|
again to stop training after the current batch.""" |
33
|
|
|
|
34
|
|
|
batch_interrupt_message = """ |
|
|
|
|
35
|
|
|
|
36
|
|
|
Blocks will complete the current batch and run extensions before exiting. If \ |
37
|
|
|
you do not want to complete this batch, press CTRL + C again. WARNING: Note \ |
38
|
|
|
that this will end training immediately, and extensions that e.g. save your \ |
39
|
|
|
training progress won't be run.""" |
40
|
|
|
|
41
|
|
|
no_model_message = """ |
|
|
|
|
42
|
|
|
|
43
|
|
|
A possible reason: one of your extensions requires the main loop to have \ |
44
|
|
|
a model. Check documentation of your extensions.""" |
45
|
|
|
|
46
|
|
|
|
47
|
|
|
class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when
    it finds one.

    The `MainLoop` also handles interruption signal SIGINT for you (e.g.
    the one program receives when you press Ctrl + C). It notes this event
    in the log and at the next iteration or epoch end the main loop will
    be gracefully finished, with calling all necessary extension callbacks
    and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks, it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        # Status flags live in the log's status so that they survive
        # checkpointing and resumption.
        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        """The model, if one was given; raises otherwise."""
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            # Install graceful-interrupt handlers; the originals are
            # remembered so they can be restored afterwards.
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    # A failing `on_error` extension must not mask the
                    # original error, which is re-raised below.
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if no extension is found or if several extensions
        with the given name are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        # Run one epoch; returns False when the data stream is exhausted.
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        # Process one batch; returns False when the epoch iterator ends.
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        # Dispatch the callback to every extension, timing each one.
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
311
|
|
|
class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log."""
This check looks for invalid names across a range of identifier kinds.
If the defaults do not match your requirements, you can set regular expressions to which the identifiers must conform.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more, refer to the Pylint documentation.