"""The event-based main loop of Blocks."""
import signal
import logging
import traceback

from blocks.config import config
from blocks.log import BACKENDS
from blocks.utils import reraise_as, unpack, change_recursion_limit
from blocks.utils.profile import Profile, Timer
from blocks.algorithms import GradientDescent
from blocks.extensions import CallbackName
from blocks.model import Model

logger = logging.getLogger(__name__)

error_message = """

Blocks will attempt to run `on_error` extensions, potentially saving data, \
before exiting and reraising the error. Note that the usual `after_training` \
extensions will *not* be run. The original error will be re-raised and also \
stored in the training log. Press CTRL + C to halt Blocks immediately."""

error_in_error_handling_message = """

Blocks will now exit. The remaining `on_error` extensions will not be run."""


epoch_interrupt_message = """

Blocks will complete this epoch of training and run extensions \
before exiting. If you do not want to complete this epoch, press CTRL + C \
again to stop training after the current batch."""

batch_interrupt_message = """

Blocks will complete the current batch and run extensions before exiting. If \
you do not want to complete this batch, press CTRL + C again. WARNING: this \
will end training immediately, and extensions that e.g. save your training \
progress won't be run."""

no_model_message = """

A possible reason: one of your extensions requires the main loop to have \
a model. Check the documentation of your extensions."""


class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: it only fetches the data from
    the data stream and feeds it to the algorithm. It expects the
    extensions to do most of the work. The corresponding callback of every
    extension is called at every stage of training. The extensions should
    communicate with each other and with the main loop object by making
    records in the log. For instance, in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when it
    finds one.

    The `MainLoop` also handles the SIGINT interruption signal for you
    (e.g. the one your program receives when you press Ctrl + C). It notes
    this event in the log, and at the end of the next iteration or epoch
    the main loop finishes gracefully, calling all the necessary extension
    callbacks and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`
        The data stream. Should support the :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented by a
        :class:`ComputationGraph` or :class:`Model` object. The main loop
        object uses the model only for optional sanity checks; it is here
        mainly for the benefit of the extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite`
        are available. If not given, `config.log_backend` will be used.
        Ignored if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
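
    # A hedged illustration of the log-record protocol described above, not
    # part of this class: a hypothetical extension that requests a stop
    # might look roughly like
    #
    #     from blocks.extensions import SimpleExtension
    #
    #     class StopSoon(SimpleExtension):
    #         def __init__(self, **kwargs):
    #             kwargs.setdefault('after_batch', True)
    #             super(StopSoon, self).__init__(**kwargs)
    #
    #         def do(self, which_callback, *args):
    #             if self.main_loop.log.status['iterations_done'] >= 100:
    #                 self.main_loop.log.current_row[
    #                     'training_finish_requested'] = True
    #
    # The built-in `blocks.extensions.FinishAfter` implements this pattern.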

    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." +
                             error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Raises an error if no extension with the given name is found,
        or if several are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)
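
    # For example, `main_loop.find_extension('Printing')` (hypothetical
    # usage, for illustration only) returns the unique Printing extension
    # attached to this main loop; extension names default to the class name.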

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)
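
    # The callback names dispatched through `_run_extensions` in a typical
    # run are, in order: 'before_training', 'before_epoch', 'before_batch',
    # 'after_batch', 'after_epoch' and 'after_training', plus
    # 'on_resumption' when resuming and 'on_error' on failure.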

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # If a keyboard interrupt is handled right at the end of an
        # iteration, the corresponding log record can only be found in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if the user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record, it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After a second CTRL + C or a SIGTERM signal (e.g. from a
        # cluster scheduler), finish the current batch.
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record, it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)


class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log."""
    pass
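

# ---------------------------------------------------------------------------
# Usage sketch. This is not part of the library: it is a minimal,
# hypothetical example of how the pieces above fit together, assuming
# Theano, Fuel and numpy are installed. It fits a single shared parameter
# `w` to the sums of randomly generated feature vectors.
if __name__ == '__main__':
    import numpy
    import theano
    from theano import tensor
    from fuel.datasets import IterableDataset
    from fuel.streams import DataStream
    from blocks.algorithms import Scale
    from blocks.extensions import FinishAfter, Printing

    # A scalar cost over one source, 'features', matching the stream below.
    features = tensor.vector('features', dtype='float64')
    w = theano.shared(numpy.float64(0.), name='w')
    cost = tensor.sqr(features.sum() - w)
    cost.name = 'cost'

    main_loop = MainLoop(
        algorithm=GradientDescent(cost=cost, parameters=[w],
                                  step_rule=Scale(learning_rate=0.1)),
        data_stream=DataStream(IterableDataset(dict(
            features=[numpy.random.rand(10) for _ in range(8)]))),
        extensions=[FinishAfter(after_n_epochs=2), Printing()])
    main_loop.run()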