| Metric | Value |
|---|---|
| Total Complexity | 51 |
| Total Lines | 261 |
| Duplicated Lines | 0 % |
Complex classes like `MainLoop` often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common way to find such a component is to look for fields or methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
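Here is one illustration of that advice. The signal-handling members of `MainLoop` (`original_sigint_handler`, `original_sigterm_handler`, `_handle_epoch_interrupt`, `_handle_batch_interrupt`, `_restore_signal_handlers`) share the `handler`/`interrupt` naming and form one cohesive component. The sketch below extracts them into a hypothetical `InterruptHandler` class. The class name, its `status` dict and its `should_finish` method are assumptions made for illustration and are not part of Blocks. The sketch also leaves out the writes to the training log.

```python
import signal


class InterruptHandler(object):
    """Hypothetical class extracted from MainLoop's signal-handling members."""

    def __init__(self):
        # Plays the role of MainLoop.status for the two interrupt flags.
        self.status = {'epoch_interrupt_received': False,
                       'batch_interrupt_received': False}
        self.original_sigint_handler = None
        self.original_sigterm_handler = None

    def install(self):
        # signal.signal may only be called from the main thread.
        self.original_sigint_handler = signal.signal(
            signal.SIGINT, self.handle_epoch_interrupt)
        self.original_sigterm_handler = signal.signal(
            signal.SIGTERM, self.handle_batch_interrupt)

    def restore(self):
        # signal.signal returns None if the previous handler was not set
        # from Python; such a handler cannot be reinstalled from here.
        if self.original_sigint_handler is not None:
            signal.signal(signal.SIGINT, self.original_sigint_handler)
        if self.original_sigterm_handler is not None:
            signal.signal(signal.SIGTERM, self.original_sigterm_handler)

    def handle_epoch_interrupt(self, signal_number, frame):
        # First Ctrl+C: finish the current epoch; a second one escalates.
        signal.signal(signal.SIGINT, self.handle_batch_interrupt)
        self.status['epoch_interrupt_received'] = True

    def handle_batch_interrupt(self, signal_number, frame):
        # Second Ctrl+C or SIGTERM: finish after the current batch.
        self.restore()
        self.status['batch_interrupt_received'] = True

    def should_finish(self, level):
        # Mirrors the interrupt part of MainLoop._check_finish_training.
        if self.status['batch_interrupt_received']:
            return True
        return (level == 'epoch' and
                self.status['epoch_interrupt_received'])
```

With these members moved out, `MainLoop` would hold one `InterruptHandler`, call `install` and `restore` around training, and consult `should_finish` from `_check_finish_training`.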
```python
"""The event-based main loop of Blocks."""
# ... (lines 2-46 not shown)


class MainLoop(object):
    """The standard main loop of Blocks.

    In the `MainLoop` a model is trained by a training algorithm using data
    extracted from a data stream. This process is scrupulously documented
    in a log object.

    The `MainLoop` itself does very little: it only fetches data from the
    data stream and feeds it to the algorithm. It expects the extensions
    to do most of the job: the respective callback of every extension is
    called at every stage of training. The extensions should communicate
    with each other and with the main loop object by making records in
    the log. For instance, to stop the training procedure an extension can
    make a record `training_finish_requested=True` in the log. The main
    loop checks for such a record after every batch and every epoch and
    terminates when it finds one.

    The `MainLoop` also handles the interrupt signal SIGINT for you (e.g.
    the one a program receives when you press Ctrl+C). It notes this event
    in the log, and at the end of the current epoch the main loop finishes
    gracefully, calling all necessary extension callbacks and waiting
    until they finish. A second SIGINT, or a SIGTERM, makes the main loop
    finish at the end of the current batch instead.

    Parameters
    ----------
    algorithm : instance of :class:`~blocks.algorithms.TrainingAlgorithm`
        The training algorithm.
    data_stream : instance of :class:`.DataStream`
        The data stream. Should support the :class:`AbstractDataStream`
        interface from Fuel.
    model : instance of :class:`.ComputationGraph`, optional
        An annotated computation graph, typically represented
        by :class:`ComputationGraph` or :class:`Model` object. The main
        loop object uses the model only for optional sanity checks; it is
        here mainly for the main loop extensions.
    log : instance of :class:`.TrainingLog`, optional
        The log. When not given, a :class:`.TrainingLog` is created.
    log_backend : str, optional
        The backend to use for the log. Currently `python` and `sqlite` are
        available. If not given, `config.log_backend` will be used. Ignored
        if `log` is passed.
    extensions : list of :class:`.TrainingExtension` instances
        The training extensions. Will be called in the same order as given
        here.

    """
```
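The docstring describes a protocol in which extensions talk to the loop only through log records. A toy loop can sketch it. Everything below is a hypothetical stand-in: `ToyLog`, `ToyLoop`, `FinishAfterBatches` and the `after_batch(log, batch)` signature are assumptions that only mimic the shape of Blocks' log, main loop and extensions. Two details come from the docstring: the `training_finish_requested` record and the check after every batch and every epoch.

```python
class TrainingFinish(Exception):
    """Raised by the toy loop when a finish request is found."""


class ToyLog(object):
    def __init__(self):
        self.status = {'iterations_done': 0, 'epochs_done': 0}
        self.rows = {}

    @property
    def current_row(self):
        # One row of records per iteration, created on first access.
        return self.rows.setdefault(self.status['iterations_done'], {})


class FinishAfterBatches(object):
    """Hypothetical extension: requests a stop after `n` batches."""

    def __init__(self, n):
        self.n = n

    def after_batch(self, log, batch):
        if log.status['iterations_done'] >= self.n:
            log.current_row['training_finish_requested'] = True


class ToyLoop(object):
    def __init__(self, epochs, extensions):
        self.epochs = epochs  # a list of epochs, each a list of batches
        self.extensions = extensions
        self.log = ToyLog()

    def _check_finish(self):
        if self.log.current_row.get('training_finish_requested', False):
            raise TrainingFinish

    def run(self):
        try:
            for epoch in self.epochs:
                for batch in epoch:
                    self.log.status['iterations_done'] += 1
                    for extension in self.extensions:
                        extension.after_batch(self.log, batch)
                    self._check_finish()  # after every batch
                self.log.status['epochs_done'] += 1
                self._check_finish()  # after every epoch
        except TrainingFinish:
            self.log.current_row['training_finished'] = True
        return self.log.status
```

For example, `ToyLoop([[1, 2, 3], [4, 5, 6]], [FinishAfterBatches(4)]).run()` stops in the middle of the second epoch, with 4 iterations and 1 completed epoch. The loop never needs to know which extension asked it to stop.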
```python
    def __init__(self, algorithm, data_stream, model=None, log=None,
                 log_backend=None, extensions=None):
        if log is None:
            if log_backend is None:
                log_backend = config.log_backend
            log = BACKENDS[log_backend]()
        if extensions is None:
            extensions = []

        self.data_stream = data_stream
        self.algorithm = algorithm
        self.log = log
        self.extensions = extensions

        self.profile = Profile()

        self._model = model

        self.status['training_started'] = False
        self.status['epoch_started'] = False
        self.status['epoch_interrupt_received'] = False
        self.status['batch_interrupt_received'] = False

    @property
    def model(self):
        if not self._model:
            raise AttributeError("no model in this main loop" +
                                 no_model_message)
        return self._model

    @property
    def iteration_state(self):
        """Quick access to the (data stream, epoch iterator) pair."""
        return (self.data_stream, self.epoch_iterator)

    @iteration_state.setter
    def iteration_state(self, value):
        (self.data_stream, self.epoch_iterator) = value

    @property
    def status(self):
        """A shortcut for `self.log.status`."""
        return self.log.status

    def run(self):
        """Starts the main loop.

        The main loop ends when a training extension makes
        a `training_finish_requested` record in the log.

        """
        # This should do nothing if the user has already configured
        # logging, and will at least enable error messages otherwise.
        logging.basicConfig()

        # If this is resumption from a checkpoint, it is crucial to
        # reset `profile.current`. Otherwise, it simply does not hurt.
        self.profile.current = []

        # Sanity check for the most common case
        if (self._model and isinstance(self._model, Model) and
                isinstance(self.algorithm, DifferentiableCostMinimizer)):
            if not (set(self._model.get_parameter_dict().values()) ==
                    set(self.algorithm.parameters)):
                logger.warning("different parameters for model and algorithm")

        with change_recursion_limit(config.recursion_limit):
            self.original_sigint_handler = signal.signal(
                signal.SIGINT, self._handle_epoch_interrupt)
            self.original_sigterm_handler = signal.signal(
                signal.SIGTERM, self._handle_batch_interrupt)
            try:
                logger.info("Entered the main loop")
                if not self.status['training_started']:
                    for extension in self.extensions:
                        extension.main_loop = self
                    self._run_extensions('before_training')
                    with Timer('initialization', self.profile):
                        self.algorithm.initialize()
                    self.status['training_started'] = True
                # We cannot write "else:" here because extensions
                # called "before_training" could have changed the status
                # of the main loop.
                if self.log.status['iterations_done'] > 0:
                    self.log.resume()
                    self._run_extensions('on_resumption')
                    self.status['epoch_interrupt_received'] = False
                    self.status['batch_interrupt_received'] = False
                with Timer('training', self.profile):
                    while self._run_epoch():
                        pass
            except TrainingFinish:
                self.log.current_row['training_finished'] = True
            except Exception as e:
                self._restore_signal_handlers()
                self.log.current_row['got_exception'] = traceback.format_exc()
                logger.error("Error occurred during training." + error_message)
                try:
                    self._run_extensions('on_error')
                except Exception:
                    logger.error(traceback.format_exc())
                    logger.error("Error occurred when running extensions." +
                                 error_in_error_handling_message)
                reraise_as(e)
            finally:
                self._restore_signal_handlers()
                if self.log.current_row.get('training_finished', False):
                    self._run_extensions('after_training')
                if config.profile:
                    self.profile.report()
```
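The error handling in `run` follows a fixed order:

- a `TrainingFinish` becomes a `training_finished` record;
- any other exception triggers the `on_error` callbacks and is re-raised;
- in every case the signal handlers are restored, and `after_training` then runs only if training finished through a finish request.

The sketch below reproduces only that ordering. It uses a hypothetical `run_with_callbacks` helper that records which callbacks would fire, and a plain `raise` where Blocks uses `reraise_as(e)`.

```python
class TrainingFinish(Exception):
    """Stands in for the exception the main loop uses to stop training."""


def run_with_callbacks(body, calls):
    """Hypothetical reduction of MainLoop.run's control flow.

    `body` plays the role of the epoch loop; `calls` collects the names
    of the callbacks in the order they would run.
    """
    row = {}
    calls.append('before_training')
    try:
        body()
    except TrainingFinish:
        row['training_finished'] = True
    except Exception:
        calls.append('on_error')
        raise
    finally:
        calls.append('restore_signal_handlers')
        if row.get('training_finished', False):
            calls.append('after_training')
    return row
```

As in the original, `after_training` is skipped both on errors and when the epoch loop ends because the data stream is exhausted without a finish request.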
```python
    def find_extension(self, name):
        """Find an extension with a given name.

        Parameters
        ----------
        name : str
            The name of the extension looked for.

        Notes
        -----
        Fails if no extension, or more than one extension, with this
        name is found.

        """
        return unpack([extension for extension in self.extensions
                       if extension.name == name], singleton=True)
```
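`find_extension` relies on `unpack(..., singleton=True)` to require exactly one match. Below is a minimal stand-alone equivalent. It uses a hypothetical `find_by_name` helper and a namedtuple in place of real extensions. Raising `ValueError` on zero or several matches is an assumption; the docstring above only says the lookup fails.

```python
from collections import namedtuple

# Hypothetical stand-in for an extension object with a `name` attribute.
Extension = namedtuple('Extension', ['name'])


def find_by_name(extensions, name):
    """Return the single extension called `name`, failing otherwise."""
    matches = [extension for extension in extensions
               if extension.name == name]
    if len(matches) != 1:
        raise ValueError("expected exactly one extension named {!r}, "
                         "found {}".format(name, len(matches)))
    return matches[0]
```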
```python
    def _run_epoch(self):
        if not self.status.get('epoch_started', False):
            try:
                self.log.status['received_first_batch'] = False
                self.epoch_iterator = (self.data_stream.
                                       get_epoch_iterator(as_dict=True))
            except StopIteration:
                return False
            self.status['epoch_started'] = True
            self._run_extensions('before_epoch')
        with Timer('epoch', self.profile):
            while self._run_iteration():
                pass
        self.status['epoch_started'] = False
        self.status['epochs_done'] += 1
        # Log might not allow mutating objects, so use += instead of append
        self.status['_epoch_ends'] += [self.status['iterations_done']]
        self._run_extensions('after_epoch')
        self._check_finish_training('epoch')
        return True

    def _run_iteration(self):
        try:
            with Timer('read_data', self.profile):
                batch = next(self.epoch_iterator)
        except StopIteration:
            if not self.log.status['received_first_batch']:
                reraise_as(ValueError("epoch iterator yielded zero batches"))
            return False
        self.log.status['received_first_batch'] = True
        self._run_extensions('before_batch', batch)
        with Timer('train', self.profile):
            self.algorithm.process_batch(batch)
        self.status['iterations_done'] += 1
        self._run_extensions('after_batch', batch)
        self._check_finish_training('batch')
        return True
```
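The `received_first_batch` flag guards against a subtle failure. If a data stream produced empty epochs, `_run_epoch` would keep returning `True`, and `run` could loop forever without processing a single batch unless some extension stopped it. The sketch below is a hypothetical reduction of the two methods. The `run_epochs` name and the `max_epochs` cap are assumptions added so that it terminates. It also shows the epoch-end bookkeeping. Note the `+=` on the stored list: it ends with a `__setitem__` call on the container, so a storage-backed log sees the update, which a bare `append` would not trigger.

```python
def run_epochs(make_epoch_iterator, max_epochs):
    """Hypothetical reduction of MainLoop._run_epoch/_run_iteration."""
    status = {'iterations_done': 0, 'epochs_done': 0, 'epoch_ends': []}
    for _ in range(max_epochs):
        received_first_batch = False
        for batch in make_epoch_iterator():
            received_first_batch = True
            # algorithm.process_batch(batch) would run here.
            status['iterations_done'] += 1
        if not received_first_batch:
            raise ValueError("epoch iterator yielded zero batches")
        status['epochs_done'] += 1
        # Like `_epoch_ends`: the iteration count at the end of each epoch.
        status['epoch_ends'] += [status['iterations_done']]
    return status
```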
```python
    def _run_extensions(self, method_name, *args):
        with Timer(method_name, self.profile):
            for extension in self.extensions:
                with Timer(type(extension).__name__, self.profile):
                    extension.dispatch(CallbackName(method_name), *args)
```
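`_run_extensions` nests two timers. The outer one is named after the callback; inside it there is one per extension class. The profile can therefore report both the total cost of, say, `after_batch` and each extension's share of it. Blocks' own `Profile` and `Timer` are not shown in this listing. The sketch below is a hypothetical re-implementation (`SimpleProfile`, `SimpleTimer`, `run_extensions` and the `Printing` extension are all assumptions). It keeps a stack of active timer names, which is the role `profile.current` appears to play, and accumulates elapsed time per path of nested names. It calls the callback with a plain `getattr` instead of Blocks' `dispatch`.

```python
import time
from collections import defaultdict


class SimpleProfile(object):
    """Hypothetical profile: total seconds per tuple of nested timer names."""

    def __init__(self):
        self.total = defaultdict(float)
        self.current = []  # stack of names of the timers currently running


class SimpleTimer(object):
    def __init__(self, name, profile):
        self.name = name
        self.profile = profile

    def __enter__(self):
        self.profile.current.append(self.name)
        self.start = time.perf_counter()

    def __exit__(self, *exc_info):
        elapsed = time.perf_counter() - self.start
        self.profile.total[tuple(self.profile.current)] += elapsed
        self.profile.current.pop()
        return False  # never swallow exceptions


class Printing(object):
    """Hypothetical extension with a no-op callback."""

    def after_batch(self, batch):
        pass


def run_extensions(profile, extensions, method_name, *args):
    # Same nesting as MainLoop._run_extensions.
    with SimpleTimer(method_name, profile):
        for extension in extensions:
            with SimpleTimer(type(extension).__name__, profile):
                getattr(extension, method_name)(*args)
```

This may also explain why `run` resets `profile.current`: a checkpoint saved from inside a callback would capture a non-empty stack of timer names.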
```python
    def _check_finish_training(self, level):
        """Checks whether the current training should be terminated.

        Parameters
        ----------
        level : {'epoch', 'batch'}
            The level at which this check was performed. In some cases, we
            only want to quit after completing the remainder of the epoch.

        """
        # If a keyboard interrupt is handled right at the end of an
        # iteration, the corresponding log record can be found only in
        # the previous row, so the interrupt flags are read from the status.
        if (self.log.current_row.get('training_finish_requested', False) or
                self.status.get('batch_interrupt_received', False)):
            raise TrainingFinish
        if (level == 'epoch' and
                self.status.get('epoch_interrupt_received', False)):
            raise TrainingFinish

    def _handle_epoch_interrupt(self, signal_number, frame):
        # Try to complete the current epoch if the user presses Ctrl+C
        logger.warning('Received epoch interrupt signal.' +
                       epoch_interrupt_message)
        signal.signal(signal.SIGINT, self._handle_batch_interrupt)
        self.log.current_row['epoch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['epoch_interrupt_received'] = True

    def _handle_batch_interrupt(self, signal_number, frame):
        # After a second Ctrl+C, or a SIGTERM (e.g. from a cluster
        # scheduler), finish the current batch
        self._restore_signal_handlers()
        logger.warning('Received batch interrupt signal.' +
                       batch_interrupt_message)
        self.log.current_row['batch_interrupt_received'] = True
        # Add a record to the status. Unlike the log record it will be
        # easy to access at later iterations.
        self.status['batch_interrupt_received'] = True

    def _restore_signal_handlers(self):
        signal.signal(signal.SIGINT, self.original_sigint_handler)
        signal.signal(signal.SIGTERM, self.original_sigterm_handler)
```
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.
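The idea behind the check fits in a few lines. Every identifier of a given kind is matched against a regular expression, and an allow-list exempts specific names. The pattern and allow-list below are illustrative assumptions modeled on older Pylint defaults for variable names. They are not Pylint's exact current configuration, which is set through options such as `variable-rgx` and `good-names`. The `invalid_names` helper is hypothetical.

```python
import re

# Illustrative snake_case pattern: 3 to 31 characters of a-z, 0-9 and "_".
VARIABLE_RGX = re.compile(r'[a-z_][a-z0-9_]{2,30}$')
GOOD_NAMES = {'i', 'j', 'k', 'ex', 'Run', '_'}


def invalid_names(names, pattern=VARIABLE_RGX, good_names=GOOD_NAMES):
    """Return the names that neither match `pattern` nor are allow-listed."""
    return [name for name in names
            if name not in good_names and not pattern.match(name)]
```

Under this illustrative pattern, the single-letter `e` in `except Exception as e:` in `run` would be reported. Renaming it, or adding it to the allow-list, would resolve the finding.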