Completed
Push — master ( e89cf9...699e27 )
by Dmitry
55:35
created

ProgressBar.get_iter_per_epoch()   B

Complexity

Conditions 7

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 1
Metric Value
cc 7
c 1
b 0
f 1
dl 0
loc 16
rs 7.3333
1
from __future__ import print_function
2
3
import datetime
4
import logging
5
from abc import ABCMeta, abstractmethod
6
7
import progressbar
8
from six import add_metaclass
9
from toolz import first
10
11
logger = logging.getLogger(__name__)
12
13
14
def callback(func):
    """Mark `func` as a training extension callback.

    The function is tagged by setting its ``_is_callback`` attribute to
    ``True``. The very same function object is returned, which makes
    this helper usable as a decorator.

    Parameters
    ----------
    func : callable
        The function to tag.

    Returns
    -------
    func : callable
        The function that was passed in, now tagged.

    """
    setattr(func, '_is_callback', True)
    return func
17
18
19
class TrainingExtension(object):
    """The base class for training extensions.

    An extension groups callbacks that share a joint context and are
    invoked at certain stages of the training procedure. Such callbacks
    typically add functionality to the training procedure, for example
    running validation on auxiliary datasets or early stopping.

    Every callback defined here is a no-op; subclasses override the
    ones they are interested in.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names help to distinguish between
        several extensions of the same type that belong to the same
        main loop. Defaults to the name of the class.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # Any falsy name (``None``, empty string) selects the class name.
        self.name = name if name else self.__class__.__name__

    @property
    def main_loop(self):
        """The main loop this extension was assigned to.

        Raises
        ------
        ValueError
            If no main loop has been assigned yet.

        """
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Runs callback with the given name.

        This method exists so that descendants of
        :class:`TrainingExtension` can intercept callback invocations
        and act on them, e.g. block a callback when a certain condition
        does not hold. The default implementation simply looks the
        callback up by name and calls it with `args`.

        Parameters
        ----------
        callback_name : str
            The name of the callback method to run.
        \\*args
            Positional arguments forwarded to the callback.

        """
        handler = getattr(self, str(callback_name))
        handler(*args)

    @callback
    def on_resumption(self):
        """The callback invoked after training is resumed."""

    @callback
    def on_error(self, exception):
        """The callback invoked when an error occurs.

        Parameters
        ----------
        exception : object
            Exception occurred during the main loop run.

        """

    @callback
    def before_training(self):
        """The callback invoked before training is started."""

    @callback
    def before_epoch(self):
        """The callback invoked before starting an epoch."""

    @callback
    def before_batch(self, batch):
        """The callback invoked before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """

    @callback
    def after_batch(self, batch):
        """The callback invoked after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """

    @callback
    def after_epoch(self):
        """The callback invoked after an epoch is finished."""

    @callback
    def after_training(self):
        """The callback invoked after training is finished."""

    @callback
    def on_interrupt(self):
        """The callback invoked when training is interrupted."""
134
135
136
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Behaves like a regular string, except that an equality comparison
    validates the other operand: it must be the name of a method of
    :class:`TrainingExtension` that carries the ``_is_callback`` marker.
    This turns typos in callback names into loud errors instead of
    comparisons that are silently ``False``.

    Instances hash exactly like the equivalent plain string, so they can
    be stored in sets and used as dictionary keys.

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    # Overriding ``__eq__`` makes Python 3 reset ``__hash__`` to ``None``,
    # which would leave instances unhashable there (while they stay
    # hashable on Python 2). Restore the string hash explicitly; it is
    # consistent with ``__eq__``, which compares by string value.
    __hash__ = str.__hash__

    def __eq__(self, other):
        # Collected on every call so that the class attributes are read at
        # comparison time rather than at definition time.
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other
152
153
154
class Predicate(object):
    """A log predicate built from the name of an integer trigger.

    Parameters
    ----------
    condition : str
        The trigger name. A name ending with ``'epochs'`` is checked
        against the ``'epochs_done'`` status entry, any other name
        against ``'iterations_done'``. A name starting with ``'every'``
        tests divisibility by `num`, any other name tests equality.
    num : int
        The number the selected counter is compared with.

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        """Evaluate the predicate against ``log.status``."""
        counter = ('epochs_done' if self.condition.endswith('epochs')
                   else 'iterations_done')
        entry = log.status[counter]
        if self.condition.startswith('every'):
            return entry % self.num == 0
        return entry == self.num
168
169
170
def has_done_epochs(log):
    """Tell whether the log reports that no epoch has been finished yet.

    Note that, despite the name, the result is ``True`` when the
    ``'epochs_done'`` status entry equals zero.

    """
    epochs_done = log.status['epochs_done']
    return epochs_done == 0
172
173
174
def always_true(log):
    """A predicate that holds unconditionally; `log` is ignored."""
    return True
176
177
178
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    The whole logic of a simple extension lives in its :meth:`do`
    method, which is called whenever one of the registered conditions is
    fulfilled. Conditions are managed through the constructor keywords,
    :meth:`set_conditions` and :meth:`add_condition`. Besides deciding
    when :meth:`do` is called, a condition can carry additional
    arguments that are passed to :meth:`do`.

    Keyword arguments that are not trigger names are forwarded to the
    base class constructor.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    before_batch : bool
        If ``True``, :meth:`do` is invoked before every batch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    on_error : bool, optional
        If ``True``, :meth:`do` is invoked when an error occurs.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch: bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "before_batch",
                                  "on_resumption", "on_interrupt", "on_error",
                                  "after_epoch", "after_batch",
                                  "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        # Trigger keywords configure this class, everything else belongs
        # to the parent constructor.
        conditions = {key: value for key, value in kwargs.items()
                      if key in trigger_keywords}
        super_kwargs = {key: value for key, value in kwargs.items()
                        if key not in trigger_keywords}
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        All previously registered conditions are discarded. Triggers
        passed with a falsy value are skipped.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters.

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        KeyError
            If a keyword with a truthy value is not a known trigger.

        """
        # Clear in place so that the list object itself is preserved.
        self._conditions[:] = []
        # Triggers that need a predicate on top of their callback.
        extra_predicates = {'before_first_epoch': has_done_epochs}
        # Triggers whose callback name differs from the trigger name.
        callback_for_trigger = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for trigger, value in kwargs.items():
            if not value:
                continue
            if trigger in self.BOOLEAN_TRIGGERS:
                predicate = extra_predicates.get(trigger, None)
            elif trigger in self.INTEGER_TRIGGERS:
                predicate = Predicate(trigger, value)
            else:
                raise KeyError("Invalid condition: {}".format(trigger))
            self.add_condition([callback_for_trigger.get(trigger, trigger)],
                               predicate=predicate)
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            called.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when it should not. If
            ``None``, an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is neither a list nor a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        check = predicate if predicate else always_true
        for name in callbacks_names:
            # A fresh empty list per condition when no arguments are given.
            extra = arguments if arguments else []
            self._conditions.append((name, check, extra))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Every registered condition whose callback name equals
        `callback_invoked` and whose predicate accepts the main loop's
        log triggers one call of :meth:`do`. The condition's additional
        arguments are appended to the ones coming from the main loop.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for name, predicate, arguments in self._conditions:
            # Keep the ``==`` form and operand order: the invoked name may
            # be a string subclass that customizes equality only.
            if not name == callback_invoked:
                continue
            if predicate(self.main_loop.log):
                self.do(callback_invoked, *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        For the ``'after_batch'`` and ``'before_batch'`` callbacks the
        first argument is attributed to the main loop; for every other
        callback all arguments are attributed to the user.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
        from_user : tuple

        """
        args = tuple(args)
        receives_batch = (which_callback == 'after_batch' or
                          which_callback == 'before_batch')
        if not receives_batch:
            return (), args
        return (args[0],), args[1:]
374
375
376
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        An iterable collection of sub-extensions to manage. Lists and
        tuples are stored as given; any other iterable (e.g. a
        generator) is copied into a list so that it can be traversed
        more than once.
    run_before_children : bool, optional
        Whether the container extension's own logic should
        be dispatched before that of the sub-extensions.
        If ``False``, the containing extension is dispatched last.
        Defaults to ``True``.

    Notes
    -----
    The main use case for this class is bundling together groups
    of extensions that are most commonly used in tandem, configured
    so as to interact with one another. Encapsulating this pattern
    in a single extension reduces boilerplate.

    Sub-extensions are dispatched in the order specified in
    ``sub_extensions``, on whatever triggers they are individually
    configured to respect.

    Sub-extensions may be run on different triggers than the containing
    extension; the trigger keywords passed to the constructor
    for this class only affect the outer extension's logic, and
    sub-extensions should be configured independently (possibly in
    a constructor for a subclass of :class:`CompositeExtension`).

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        # The collection is traversed repeatedly: once when the main loop
        # is assigned and once per dispatched callback. A one-shot iterable
        # would be exhausted by the first traversal and the children would
        # silently never run, so materialize anything that is not already
        # a list or a tuple.
        if not isinstance(sub_extensions, (list, tuple)):
            sub_extensions = list(sub_extensions)
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        """Dispatch a callback to this extension and its sub-extensions.

        The container's own conditions are checked either before or
        after the sub-extensions, depending on ``run_before_children``.

        """
        def run_super():
            super(CompositeExtension, self).dispatch(callback_invoked,
                                                     *from_main_loop)
        if self.run_before_children:
            run_super()

        for ext in self.sub_extensions:
            ext.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            run_super()

    @property
    def main_loop(self):
        """The main loop shared by this extension and its children."""
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value
        # Propagate the main loop to every managed extension.
        for sub in self.sub_extensions:
            sub.main_loop = value

    def do(self, which_callback, *args):
        """Do nothing by default; subclasses may add their own logic."""
        pass
437
438
439
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    The request is made by setting the ``'training_finish_requested'``
    entry of the current log row to ``True``.

    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        current_row = self.main_loop.log.current_row
        current_row['training_finish_requested'] = True
446
447
448
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    Unless overridden through keyword arguments, the extension is
    triggered before the first epoch, on resumption, after training,
    after every epoch and on interruption.

    """
    def __init__(self, **kwargs):
        # The order matters: it determines the order in which the
        # defaults are added to `kwargs`.
        default_triggers = ("before_first_epoch", "on_resumption",
                            "after_training", "after_epoch", "on_interrupt")
        for trigger in default_triggers:
            kwargs.setdefault(trigger, True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the items of a mapping sorted by key.

        Entries whose key starts with an underscore are skipped.

        """
        for attr, value in sorted(attribute_tuples.items(), key=first):
            if attr.startswith("_"):
                continue
            print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        """Print a headline for the callback followed by the log state.

        The training status and the current log row are printed for
        every callback except ``'on_interrupt'``.

        """
        log = self.main_loop.log
        separator = 79 * "-"
        headline = None
        print_status = True

        # Plain ``==`` comparisons are used on purpose; a dictionary lookup
        # would require the callback name to be hashable.
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            headline = "BEFORE FIRST EPOCH"
        elif which_callback == "on_resumption":
            headline = "TRAINING HAS BEEN RESUMED"
        elif which_callback == "after_training":
            headline = "TRAINING HAS BEEN FINISHED:"
        elif which_callback == "after_epoch":
            headline = "AFTER ANOTHER EPOCH"
        elif which_callback == "on_interrupt":
            headline = "TRAINING HAS BEEN INTERRUPTED"
            print_status = False

        print()
        print(separator)
        if headline is not None:
            print(headline)
        print(separator)
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
488
489
490
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch
    by querying the `indices`, `num_examples` and `batch_size`
    attributes from the :class:`IterationScheme`. When this information is
    not available it will display a simplified progress bar that does not
    include the estimated time until the end of this epoch.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extension will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        del state['bar']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # The bar is recreated lazily by `before_batch`.
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        int or None
            The number of iterations in one epoch, or ``None`` when it
            cannot be inferred from the iteration scheme (missing
            attributes, an unknown number of examples, or a batch size
            that is ``None`` or zero).

        """
        iter_scheme = self.main_loop.data_stream.iteration_scheme
        has_batch_size = hasattr(iter_scheme, 'batch_size')

        if hasattr(iter_scheme, 'indices'):
            num_examples = len(iter_scheme.indices)
            if not has_batch_size:
                # schemes that extend IndexScheme: one iteration per index
                return num_examples
            # otherwise: schemes that extend BatchScheme
        elif has_batch_size and hasattr(iter_scheme, 'num_examples'):
            # ConstantScheme
            num_examples = iter_scheme.num_examples
        else:
            return None

        batch_size = iter_scheme.batch_size
        # The attributes may exist without carrying a usable value; fall
        # back to the bar of unknown length instead of failing on the
        # division below.
        if num_examples is None or not batch_size:
            return None

        # Round up: a trailing partial batch still is one iteration.
        # Rounding down would let `before_batch` count past the maximum
        # of the bar whenever the number of examples is not a multiple of
        # the batch size. NOTE(review): assumes the scheme emits the
        # partial final batch; if it drops it, the bar merely ends one
        # step short.
        return -(-num_examples // batch_size)

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            # Unknown epoch length: no percentage and no ETA.
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        """Reset the per-epoch iteration counter."""
        self.iter_count = 0

    def after_epoch(self):
        """Finish and discard the bar of the epoch, if there is one."""
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        """Advance the bar by one step, creating it on first use."""
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
584
585
586
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.
        A non-empty prefix is joined to the record name by an underscore.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    Created with callbacks like ``every_n_batches`` this extension
    averages the time.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g.  by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)

        def zeroed():
            # One independent counter table per call.
            table = {}
            for level in ['batch', 'epoch']:
                table[level] = {'train': 0, 'read_data': 0}
            return table

        # Timer readings and the counter values they were taken at.
        self.current = zeroed()
        self.previous = zeroed()
        self.current_index = zeroed()
        self.previous_index = zeroed()
        self.prefix = prefix + '_' if prefix else prefix

    def do(self, which_callback, *args):
        """Write timing records into the current log row.

        Raises
        ------
        ValueError
            If called for a callback other than ``'before_epoch'``,
            ``'after_batch'`` or ``'after_epoch'``.

        """
        log = self.main_loop.log
        current_row = log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level, counter = 'batch', 'iterations_done'
        elif which_callback == 'after_epoch':
            level, counter = 'epoch', 'epochs_done'
        else:
            raise ValueError('wrong callback type `{}`'.format(which_callback))

        readings, old_readings = self.current[level], self.previous[level]
        marks, old_marks = (self.current_index[level],
                            self.previous_index[level])
        for action in ['train', 'read_data']:
            # Shift the counter window forward by one call.
            old_marks[action] = marks[action]
            marks[action] = log.status[counter]
            index_now = marks[action]
            index_before = old_marks[action]
            if index_now == index_before:
                logger.debug('Timing extension was called twice this %s, '
                             'log was not updated.', level)
                # Nothing to report for this level
                continue

            old_readings[action] = readings[action]
            readings[action] = profile['training', 'epoch', action]

            # Average over the batches/epochs elapsed since the last call.
            this_key = (self.prefix + 'time_{}_this_{}').format(action, level)
            current_row[this_key] = (
                (readings[action] - old_readings[action]) /
                (index_now - index_before))
            total_key = (self.prefix + 'time_{}_total').format(action)
            current_row[total_key] = readings[action]
666
667
668
class Timestamp(SimpleExtension):
    """Adds a human readable (ISO 8601) timestamp to the log.

    Parameters
    ----------
    log_record : str, optional
        The record name to use. Defaults to 'timestamp'.
    separator : str, optional
        Separator between the date and time. ISO 8601 specifies 'T'.
        Here, we default to ' ' (blank space) for human readability.

    Notes
    -----
    By default, triggers after every epoch as well as before training
    starts, after training finishes, when an error occurs or when training
    is interrupted or resumed, as these are all generally useful
    circumstances for which to have a timestamp. These can be disabled
    by passing `False` as the appropriate keyword argument; see
    :class:`SimpleExtension`.

    """
    DEFAULT_LOG_RECORD = 'timestamp'

    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
                 **kwargs):
        self.log_record = log_record
        self.separator = separator
        # Enabled unless the caller explicitly passed a value for them.
        for trigger in ('before_training', 'after_epoch', 'on_error',
                        'on_interrupt', 'on_resumption', 'after_training'):
            kwargs.setdefault(trigger, True)
        super(Timestamp, self).__init__(**kwargs)

    def do(self, *args):
        """Store the current timestamp in the current log row."""
        stamp = self.get_timestamp()
        self.main_loop.log.current_row[self.log_record] = stamp

    def get_timestamp(self):
        """Return the current local time as an ISO 8601 string."""
        # Separated into a method to override for ease of testing.
        now = datetime.datetime.now()
        return now.isoformat(self.separator)
707