"""The event-based main loop of Blocks."""
import signal
import logging
import traceback

from blocks.config import config
from blocks.log import BACKENDS
from blocks.utils import reraise_as, unpack, change_recursion_limit
from blocks.utils.profile import Profile, Timer
from blocks.algorithms import GradientDescent
from blocks.extensions import CallbackName
from blocks.model import Model

logger = logging.getLogger(__name__)

error_message = """

Blocks will attempt to run `on_error` extensions, potentially saving data, \
before exiting and reraising the error. Note that the usual `after_training` \
extensions will *not* be run. The original error will be re-raised and also \
stored in the training log. Press CTRL + C to halt Blocks immediately."""

error_in_error_handling_message = """

Blocks will now exit. The remaining `on_error` extensions will not be run."""


epoch_interrupt_message = """

Blocks will complete this epoch of training and run extensions \
before exiting. If you do not want to complete this epoch, press CTRL + C \
again to stop training after the current batch."""

batch_interrupt_message = """

Blocks will complete the current batch and run extensions before exiting. If \
you do not want to complete this batch, press CTRL + C again. WARNING: Note \
that this will end training immediately, and extensions that e.g. save your \
training progress won't be run."""

no_model_message = """

A possible reason: one of your extensions requires the main loop to have \
a model. Check documentation of your extensions."""


class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: it only fetches data from the
    data stream and feeds it to the algorithm. It expects the extensions
    to do most of the job. A respective callback of every extension is
    called at every stage of training. The extensions should communicate
    between themselves and with the main loop object by means of making
    records in the log. For instance, in order to stop the training
    procedure an extension can make a record
    `training_finish_requested=True` in the log. The main loop checks for
    such a record after every batch and every epoch, and terminates when it
    finds one.

    The `MainLoop` also handles the SIGINT interruption signal for you
    (e.g. the one a program receives when you press Ctrl + C). It notes
    this event in the log, and at the next iteration or epoch end the main
    loop finishes gracefully, calling all the necessary extension callbacks
    and waiting until they finish.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`
        The data stream. Should support the :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by a :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks; it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
||
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.epoch_iterator = None
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, GradientDescent)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We can not write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error', e)
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()

    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Raises an error if no extension with the given name is found, or
        if several are found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)

    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True

    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)

    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # When a keyboard interrupt is handled right at the end of an
        # iteration, the corresponding log record can be found only in
        # the previous row.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if the user presses CTRL + C.
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record, it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After a second CTRL + C or a SIGTERM signal (e.g. from a cluster
        # scheduler), finish the current batch.
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record, it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)


class TrainingFinish(Exception):
    """An exception raised when a finish request is found in the log."""
    pass
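The control flow above — extensions that communicate with the driver loop only through shared status records, checked after every batch and every epoch — can be sketched in a stripped-down, dependency-free form. All names here (`SketchLoop`, `finish_after_five`) are hypothetical illustrations, not part of the Blocks API:

```python
class SketchLoop:
    """A minimal sketch of the MainLoop batch/epoch driver pattern."""

    def __init__(self, batches_per_epoch, extensions):
        self.batches_per_epoch = batches_per_epoch
        self.extensions = extensions
        # The status dict plays the role of the training log's status:
        # extensions write records here; the loop reads them.
        self.status = {'iterations_done': 0, 'epochs_done': 0,
                       'training_finish_requested': False}

    def _notify(self, callback_name):
        # Dispatch a callback to every extension, in order.
        for extension in self.extensions:
            extension(callback_name, self.status)

    def run(self):
        finished = False
        while not finished:
            self._notify('before_epoch')
            for _ in range(self.batches_per_epoch):
                self.status['iterations_done'] += 1
                self._notify('after_batch')
                # Checked after every batch, as in _check_finish_training.
                if self.status['training_finish_requested']:
                    finished = True
                    break
            if not finished:
                self.status['epochs_done'] += 1
                self._notify('after_epoch')
                # Also checked after every epoch.
                finished = self.status['training_finish_requested']
        self._notify('after_training')


def finish_after_five(callback_name, status):
    # An "extension" that requests a graceful stop after five batches.
    if callback_name == 'after_batch' and status['iterations_done'] >= 5:
        status['training_finish_requested'] = True


loop = SketchLoop(batches_per_epoch=3, extensions=[finish_after_five])
loop.run()
# The loop stops mid-way through the second epoch: five iterations done,
# one full epoch completed.
```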
||
314 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.
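The two-stage interrupt escalation used by `_handle_epoch_interrupt` and `_handle_batch_interrupt` — first Ctrl + C finishes the epoch, second finishes the batch — follows a general pattern that can be shown in isolation. This is an illustrative sketch with hypothetical names, not the Blocks implementation:

```python
import signal

status = {'epoch_interrupt_received': False,
          'batch_interrupt_received': False}


def handle_batch_interrupt(signum, frame):
    # Second CTRL + C (or SIGTERM): stop right after the current batch,
    # and hand SIGINT back to the default handler so a further CTRL + C
    # terminates the process immediately.
    signal.signal(signal.SIGINT, signal.default_int_handler)
    status['batch_interrupt_received'] = True


def handle_epoch_interrupt(signum, frame):
    # First CTRL + C: finish the current epoch gracefully, and rebind
    # SIGINT so that the next CTRL + C escalates to a batch interrupt.
    signal.signal(signal.SIGINT, handle_batch_interrupt)
    status['epoch_interrupt_received'] = True


# signal.signal returns the previous handler, which is how the main loop
# is able to restore the original handlers afterwards.
previous_handler = signal.signal(signal.SIGINT, handle_epoch_interrupt)
```

The main loop itself only polls `status` between batches and epochs; the handlers never interrupt a batch mid-computation.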