1 | """The event-based main loop of Blocks.""" |
||
2 | import signal |
||
3 | import logging |
||
4 | import traceback |
||
5 | |||
6 | from blocks.config import config |
||
7 | from blocks.log import BACKENDS |
||
8 | from blocks.utils import reraise_as, unpack, change_recursion_limit |
||
9 | from blocks.utils.profile import Profile, Timer |
||
10 | from blocks.algorithms import GradientDescent |
||
11 | from blocks.extensions import CallbackName |
||
12 | from blocks.model import Model |
||
13 | |||
14 | logger = logging.getLogger(__name__) |
||
15 | |||
# User-facing message templates. Each is appended verbatim to a log/warning
# line at the corresponding interruption or error point in `MainLoop.run`
# and the signal handlers below, so their text is part of runtime behavior.
error_message = """

Blocks will attempt to run `on_error` extensions, potentially saving data, \
before exiting and reraising the error. Note that the usual `after_training` \
extensions will *not* be run. The original error will be re-raised and also \
stored in the training log. Press CTRL + C to halt Blocks immediately."""

error_in_error_handling_message = """

Blocks will now exit. The remaining `on_error` extensions will not be run."""


epoch_interrupt_message = """

Blocks will complete this epoch of training and run extensions \
before exiting. If you do not want to complete this epoch, press CTRL + C \
again to stop training after the current batch."""

batch_interrupt_message = """

Blocks will complete the current batch and run extensions before exiting. If \
you do not want to complete this batch, press CTRL + C again. WARNING: Note \
that this will end training immediately, and extensions that e.g. save your \
training progress won't be run."""

no_model_message = """

A possible reason: one of your extensions requires the main loop to have \
a model. Check documentation of your extensions."""
||
45 | |||
46 | |||
class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: only fetching the data from the
    data stream and feeding it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch and terminates when
    finds it.

    The `MainLoop` also handles interruption signal SIGINT for you (e.g.
    the one program receives when you press Ctrl + C). It notes this event
    in the log and at the next iteration or epoch end the main loop will
    be gracefully finished, with calling all necessary extension callbacks
    and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`.
        The data stream. Should support :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks, it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        # `self.status` delegates to `self.log.status`, so these flags are
        # stored in the log and survive checkpointing/resumption.
        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        # Raising AttributeError (rather than returning None) keeps
        # `hasattr(main_loop, 'model')` meaningful for extensions that
        # optionally use a model.
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case: warn when the model's
        # parameters and the algorithm's parameters are not the same set.
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            # Install graceful-interrupt handlers, remembering the previous
            # ones so they can be restored in `_restore_signal_handlers`.
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                # Normal termination path: a finish request (or interrupt)
                # was noticed by `_check_finish_training`.
                self.log.current_row['training_finished'] = True
            except Exception as e:
                # Restore handlers first so a CTRL + C during `on_error`
                # extensions halts immediately, as `error_message` promises.
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occured during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    # Deliberately broad: a failing `on_error` extension must
                    # not mask the original training error, which is re-raised
                    # below.
                    logger.error(traceback.format_exc())
                    logger.error("Error occured when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                # `after_training` runs only on clean termination; on error
                # only the `on_error` extensions above have been run.
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Will crash if there is no extension or several extensions found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        """Run one epoch of training.

        Returns ``False`` when the data stream refuses to produce another
        epoch iterator (``StopIteration``), ``True`` otherwise. May raise
        :class:`TrainingFinish` via `_check_finish_training`.

        """
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        """Process one batch.

        Returns ``False`` when the epoch iterator is exhausted, ``True``
        otherwise. May raise :class:`TrainingFinish` via
        `_check_finish_training`.

        """
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            # An epoch with zero batches is treated as a user error rather
            # than silently counted as a finished epoch.
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        """Dispatch the `method_name` callback on every extension, in order,
        timing each extension individually in the profile."""
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # In case when keyboard interrupt is handled right at the end of
        # the iteration the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        """SIGINT handler: finish the current epoch, then stop gracefully."""
        # Try to complete the current epoch if user presses CTRL + C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        # A second CTRL + C escalates to a batch-level interrupt.
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        """SIGTERM / second-SIGINT handler: finish the current batch, then
        stop."""
        # After 2nd CTRL + C or SIGTERM signal (from cluster) finish batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        """Reinstall the SIGINT/SIGTERM handlers that were active before
        `run` replaced them."""
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
309 | |||
310 | |||
class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log.

    Used internally by :class:`MainLoop` as a control-flow signal: raised
    by `_check_finish_training` and caught in `run`, where it triggers the
    graceful `after_training` shutdown path.
    """
    # The redundant `pass` after the docstring was removed: a docstring
    # alone is a sufficient class body.