Completed
Pull Request — master (#1109)
by David
04:48
created

CompositeExtension.run_super()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 3
rs 10
1
from __future__ import print_function
2
3
import logging
4
from abc import ABCMeta, abstractmethod
5
6
import progressbar
7
from six import add_metaclass
8
from toolz import first
9
10
logger = logging.getLogger(__name__)
11
12
13
def callback(func):
    """Mark a function as a training-extension callback.

    The function is tagged by setting its ``_is_callback`` attribute to
    ``True``; the very same function object is returned, so this can be
    used as a decorator.

    Parameters
    ----------
    func : callable
        The function to tag.

    Returns
    -------
    callable
        `func` itself, now carrying the ``_is_callback`` marker.

    """
    setattr(func, '_is_callback', True)
    return func
16
17
18
class TrainingExtension(object):
    """Base class of all training extensions.

    An extension bundles callbacks that share a common context and are
    invoked at fixed stages of the training procedure. Callbacks usually
    add some functionality to training, such as running validation on
    auxiliary datasets or early stopping.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names help to tell apart several
        extensions of the same type that belong to the same main loop.
        Defaults to the name of the class.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # Any falsy name (``None``, ``''``) falls back to the class name.
        self.name = name or self.__class__.__name__

    @property
    def main_loop(self):
        """The main loop this extension is attached to.

        Raises
        ------
        ValueError
            If no main loop has been assigned yet.

        """
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Run the callback that carries the given name.

        This indirection exists so that descendants of
        :class:`TrainingExtension` can intercept callback invocations,
        e.g. to block them while some condition does not hold. The
        default implementation looks the method up by name and calls it
        with `args`.

        Parameters
        ----------
        callback_name : str
            The name of the callback method to invoke.
        \\*args
            Positional arguments forwarded to the callback.

        """
        handler = getattr(self, str(callback_name))
        handler(*args)

    @callback
    def on_resumption(self):
        """Callback invoked after training is resumed."""

    @callback
    def on_error(self):
        """Callback invoked when an error occurs."""

    @callback
    def before_training(self):
        """Callback invoked before training is started."""

    @callback
    def before_epoch(self):
        """Callback invoked before an epoch starts."""

    @callback
    def before_batch(self, batch):
        """Callback invoked before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """

    @callback
    def after_batch(self, batch):
        """Callback invoked after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """

    @callback
    def after_epoch(self):
        """Callback invoked after an epoch is finished."""

    @callback
    def after_training(self):
        """Callback invoked after training is finished."""

    @callback
    def on_interrupt(self):
        """Callback invoked when training is interrupted."""
127
128
129
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Behaves like a regular string, except that ``==`` refuses to compare
    against anything that is not the name of an attribute of
    :class:`TrainingExtension` carrying a true ``_is_callback`` marker.
    A misspelt callback name in a comparison therefore fails loudly
    instead of silently evaluating to ``False``.

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    # A class that overrides ``__eq__`` without defining ``__hash__`` gets
    # ``__hash__ = None`` on Python 3, which would make callback names
    # unusable as dictionary keys or set members. Equality (when it does
    # not raise) agrees with plain string equality, so the string hash is
    # the consistent choice.
    __hash__ = str.__hash__

    def __eq__(self, other):
        # Collected on every comparison from the class dictionary of
        # TrainingExtension itself (callbacks added by subclasses are not
        # considered).
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other
145
146
147
class Predicate(object):
    """Counter-based predicate over the training log.

    Parameters
    ----------
    condition : str
        Name of the trigger. A name ending in ``'epochs'`` selects the
        ``'epochs_done'`` counter, any other name selects
        ``'iterations_done'``. A name starting with ``'every'`` makes the
        predicate periodic; otherwise it holds for one exact count only.
    num : int
        The period (for ``every...`` triggers) or the exact count.

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        """Evaluate the predicate against ``log.status``.

        Returns ``True`` when the selected counter is a multiple of
        `num` (periodic triggers) or equal to `num` (one-shot triggers).

        """
        counter = ('epochs_done' if self.condition.endswith('epochs')
                   else 'iterations_done')
        done = log.status[counter]
        if self.condition.startswith('every'):
            return done % self.num == 0
        return done == self.num
161
162
163
def has_done_epochs(log):
    """Tell whether the log reports that no epoch has been done yet.

    Despite what the name suggests, the result is ``True`` exactly when
    ``log.status['epochs_done']`` equals zero. It serves as the predicate
    of the ``before_first_epoch`` trigger of :class:`SimpleExtension`.

    """
    epochs_done = log.status['epochs_done']
    return epochs_done == 0
165
166
167
def always_true(log):
    """Predicate that holds unconditionally.

    The `log` argument is accepted only to match the predicate call
    signature and is never inspected.

    """
    return True
169
170
171
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    All logic of simple extensions is concentrated in the method
    :meth:`do`.  This method is called when certain conditions are
    fulfilled. The user can manage the conditions by calling the
    :meth:`add_condition` method and by passing arguments to the
    constructor.  In addition to specifying when :meth:`do` is called, it
    is possible to specify additional arguments passed to :meth:`do`
    under different conditions.

    Keyword arguments that are not one of the trigger keywords listed
    below are forwarded to the :class:`TrainingExtension` constructor.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch : bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "on_resumption",
                                  "on_interrupt", "after_epoch",
                                  "after_batch", "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        # List of (callback name, predicate, extra arguments) triples.
        self._conditions = []
        super_kwargs = {}
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        conditions = {}
        # Split trigger keywords from whatever the base class accepts.
        for key, value in kwargs.items():
            if key in trigger_keywords:
                conditions[key] = value
            else:
                super_kwargs[key] = value
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        All previously registered conditions are discarded and replaced
        by the ones described by the keyword arguments.

        Parameters
        ----------
        kwargs
            See the :class:`SimpleExtension` docstring for a list of
            possible parameters. Triggers given with a false value
            (``False``, ``None``, ``0``) are not registered.

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        KeyError
            If a keyword is not a known trigger. The check covers every
            keyword regardless of its value and runs before anything is
            modified, so a failed call leaves the existing conditions
            untouched.

        """
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        for key in kwargs:
            if key not in trigger_keywords:
                raise KeyError("Invalid condition: {}".format(key))

        # Cleared in place so that references to the list stay valid.
        self._conditions[:] = []
        predicates = {'before_first_epoch': has_done_epochs}
        # Triggers that are not themselves callback names, mapped to the
        # callback in which they are checked. Triggers missing from this
        # mapping are used as the callback name directly.
        conditions = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            if not value:
                continue
            if key in self.BOOLEAN_TRIGGERS:
                predicate = predicates.get(key, None)
            else:
                # Integer trigger: the value is the count or period.
                predicate = Predicate(key, value)
            self.add_condition([conditions.get(key, key)],
                               predicate=predicate)
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            called.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when should not. If ``None``,
            an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is neither a list nor a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        if not arguments:
            arguments = []
        if not predicate:
            predicate = always_true
        for _callback_name in callbacks_names:
            self._conditions.append((_callback_name, predicate, arguments))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Also adds additional arguments if specified for a condition.
        :meth:`do` is called once for every registered condition that
        matches `callback_invoked` and whose predicate holds for the main
        loop's log.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for callback_name, predicate, arguments in self._conditions:
            if (callback_name == callback_invoked and
                    predicate(self.main_loop.log)):
                self.do(callback_invoked, *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        This method is the right tool to use. For the ``before_batch``
        and ``after_batch`` callbacks the first argument is attributed to
        the main loop; for every other callback all arguments are
        attributed to the user.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
        from_user : tuple

        """
        args = tuple(args)
        if (which_callback == 'after_batch' or
                which_callback == 'before_batch'):
            return (args[0],), args[1:]
        return (), args
366
367
368
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        An iterable collection of sub-extensions to manage. Lists and
        tuples are stored as given; any other iterable is copied into a
        list once, at construction time.
    run_before_children : bool, optional
        Whether the container extension's own logic should
        be dispatched before that of the sub-extensions.
        If ``False``, the containing extension is dispatched last.
        Defaults to ``True``.

    Notes
    -----
    The main use case for this class is bundling together groups
    of extensions that are most commonly used in tandem, configured
    so as to interact with one another. Encapsulating this pattern
    in a single extension reduces boilerplate.

    Sub-extensions are dispatched in the order specified in
    ``sub_extensions``, on whatever triggers they are individually
    configured to respect.

    Sub-extensions may be run on different triggers than the containing
    extension; the trigger keywords passed to the constructor
    for this class only affect the outer extension's logic, and
    sub-extensions should be configured independently (possibly in
    a constructor for a subclass of :class:`CompositeExtension`).

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        # The collection is walked on every dispatch and whenever the main
        # loop is assigned. A one-shot iterator (e.g. a generator) would be
        # empty after the first walk, so it is materialised here.
        if not isinstance(sub_extensions, (list, tuple)):
            sub_extensions = list(sub_extensions)
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        """Dispatch the callback to this extension and to its children.

        The container's own condition check (inherited from
        :class:`SimpleExtension`) runs before the sub-extensions when
        ``run_before_children`` is true and after them otherwise.

        """
        def run_super():
            super(CompositeExtension, self).dispatch(callback_invoked,
                                                     *from_main_loop)

        if self.run_before_children:
            run_super()

        for ext in self.sub_extensions:
            ext.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            run_super()

    @property
    def main_loop(self):
        # Redeclared only so that the setter below can be overridden.
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        # Assigning a main loop to the container propagates it to every
        # sub-extension.
        self._main_loop = value
        for sub in self.sub_extensions:
            sub.main_loop = value

    def do(self, which_callback, *args):
        """Do nothing; subclasses may override to add container logic."""
        pass
429
430
431
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    The triggers are the usual :class:`SimpleExtension` keyword
    arguments. When one fires, a finish request is recorded in the
    current row of the main loop's log.

    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        """Flag ``training_finish_requested`` in the current log row."""
        current_row = self.main_loop.log.current_row
        current_row['training_finish_requested'] = True
438
439
440
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    Unless overridden through keyword arguments, the extension fires
    before the first epoch, on resumption, after training, after every
    epoch and on interrupt.

    """
    def __init__(self, **kwargs):
        default_triggers = ("before_first_epoch", "on_resumption",
                            "after_training", "after_epoch", "on_interrupt")
        for trigger in default_triggers:
            kwargs.setdefault(trigger, True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the items of a mapping, sorted by key.

        Entries whose key starts with an underscore are skipped.

        """
        for attr, value in sorted(attribute_tuples.items(), key=first):
            if attr.startswith("_"):
                continue
            print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        """Print a banner for the callback, then status and log records.

        The status and the current log row are printed for every callback
        except ``on_interrupt``, for which only the banner is shown.

        """
        log = self.main_loop.log
        separator = 79 * "-"

        print()
        print(separator)

        header = None
        print_status = True
        if which_callback == "before_epoch":
            # Only the very first epoch gets a banner text.
            if log.status['epochs_done'] == 0:
                header = "BEFORE FIRST EPOCH"
        elif which_callback == "on_resumption":
            header = "TRAINING HAS BEEN RESUMED"
        elif which_callback == "after_training":
            header = "TRAINING HAS BEEN FINISHED:"
        elif which_callback == "after_epoch":
            header = "AFTER ANOTHER EPOCH"
        elif which_callback == "on_interrupt":
            header = "TRAINING HAS BEEN INTERRUPTED"
            print_status = False
        if header is not None:
            print(header)

        print(separator)
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            iterations_done = log.status['iterations_done']
            print("Log records from the iteration {}:".format(
                iterations_done))
            self._print_attributes(log.current_row)
        print()
480
481
482
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch
    by querying the `num_batches`, `num_examples` and `batch_size`
    attributes from the :class:`IterationScheme`. When this information is
    not available it will display a simplified progress bar that does not
    include the estimated time until the end of this epoch.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extension will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        # The bar is created lazily on the first batch of every epoch.
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        del state['bar']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # The bar was dropped from the pickled state; start without one.
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        int or None
            ``num_batches`` of the iteration scheme when it has that
            attribute; otherwise ``num_examples`` divided by
            ``batch_size`` rounded up, when both are present; otherwise
            ``None``.

        """
        iter_scheme = self.main_loop.data_stream.iteration_scheme
        if hasattr(iter_scheme, 'num_batches'):
            return iter_scheme.num_batches
        if (hasattr(iter_scheme, 'num_examples') and
                hasattr(iter_scheme, 'batch_size')):
            # Rounded up so that a trailing partial batch is counted.
            # Rounding down would make the bar's maximum smaller than the
            # value passed to ``update`` on that last batch.
            # NOTE(review): assumes the scheme emits the final partial
            # batch; if it drops it the bar merely ends one step early.
            return -(-iter_scheme.num_examples // iter_scheme.batch_size)
        return None

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            # Unknown epoch length: no percentage and no ETA.
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        self.iter_count = 0

    def after_epoch(self):
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
569
570
571
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    By default the extension fires before the first epoch and after
    every epoch; per-batch figures are only produced when the
    ``after_batch`` trigger is enabled as well.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g.  by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)
        # Latest and previous cumulative totals, kept per reporting level
        # so that batch and epoch differences do not disturb each other.
        self.current = {
            level: {'train': 0, 'read_data': 0}
            for level in ['batch', 'epoch']
        }
        self.previous = {
            level: {'train': 0, 'read_data': 0}
            for level in ['batch', 'epoch']
        }

    def do(self, which_callback, *args):
        """Write timing records into the current log row.

        On ``before_epoch`` only ``time_initialization`` is recorded. On
        ``after_batch`` and ``after_epoch`` the totals for training and
        data reading are recorded, together with how much they grew since
        the previous record of the same level. Any other callback is
        ignored.

        """
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level = 'batch'
        elif which_callback == 'after_epoch':
            level = 'epoch'
        else:
            # No reporting level exists for other triggers (for example
            # ``after_training``). Without this branch ``level`` would be
            # unbound below and the loop would raise UnboundLocalError.
            return
        for action in ['train', 'read_data']:
            self.previous[level][action] = self.current[level][action]
            self.current[level][action] = profile['training', 'epoch', action]
            current_row['time_{}_this_{}'.format(action, level)] = \
                self.current[level][action] - self.previous[level][action]
            current_row['time_{}_total'.format(action)] = \
                self.current[level][action]
619