Completed
Push — master ( 5b14e7...41b8fa )
by David
55:16
created

Timestamp.do()   A

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 2
rs 10
1
from __future__ import print_function
2
3
import datetime
4
import logging
5
from abc import ABCMeta, abstractmethod
6
7
import progressbar
8
from six import add_metaclass
9
from toolz import first
10
11
logger = logging.getLogger(__name__)
12
13
14
def callback(func):
    """Mark ``func`` as a training-extension callback.

    Sets the ``_is_callback`` attribute of ``func`` to ``True`` and returns
    the very same object, which makes the function usable as a decorator.

    """
    func._is_callback = True
    return func
17
18
19
class TrainingExtension(object):
    """The base class for training extensions.

    An extension is a set of callbacks sharing a joint context that are
    invoked at certain stages of the training procedure. These callbacks
    typically add a certain functionality to the training procedure,
    e.g. running validation on auxiliary datasets or early stopping.

    Parameters
    ----------
    name : str, optional
        The name of the extension. The names are useful in order to
        distinguish between several extensions of the same type that
        belong to the same main loop. By default the name is set to
        the name of the class.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # Any falsy name (``None``, ``""``) falls back to the class name.
        if not name:
            name = self.__class__.__name__
        self.name = name

    @property
    def main_loop(self):
        """The main loop this extension is attached to.

        Raises
        ------
        ValueError
            If no main loop has been assigned yet.

        """
        if not hasattr(self, '_main_loop'):
            raise ValueError("main loop must be assigned to extension first")
        return self._main_loop

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Runs callback with the given name.

        The reason for having this method is to allow
        the descendants of the :class:`TrainingExtension` to intercept
        callback invocations and do something with them, e.g. block
        when certain condition does not hold. The default implementation
        simply invokes the callback by its name.

        Parameters
        ----------
        callback_name : str
            The name of the method to invoke; it is converted with
            ``str`` before the attribute lookup.
        \\*args
            Positional arguments forwarded to the callback.

        """
        getattr(self, str(callback_name))(*args)

    @callback
    def on_resumption(self):
        """The callback invoked after training is resumed."""
        pass

    @callback
    def on_error(self, exception):
        """The callback invoked when an error occurs.

        Parameters
        ----------
        exception : object
            Exception occurred during the main loop run.

        """
        pass

    @callback
    def before_training(self):
        """The callback invoked before training is started."""
        pass

    @callback
    def before_epoch(self):
        """The callback invoked before starting an epoch."""
        pass

    @callback
    def before_batch(self, batch):
        """The callback invoked before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """
        pass

    @callback
    def after_batch(self, batch):
        """The callback invoked after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """
        pass

    @callback
    def after_epoch(self):
        """The callback invoked after an epoch is finished."""
        pass

    @callback
    def after_training(self):
        """The callback invoked after training is finished."""
        pass

    @callback
    def on_interrupt(self):
        """The callback invoked when training is interrupted."""
        pass
134
135
136
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Only ``==`` is overridden: the right-hand operand is validated against
    the methods of :class:`TrainingExtension` that carry the
    ``_is_callback`` marker before the plain string comparison is made.

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    def __eq__(self, other):
        # Collected on every comparison from the class namespace, so only
        # callbacks declared directly on TrainingExtension are accepted.
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other

    # Overriding ``__eq__`` makes Python 3 reset ``__hash__`` to ``None``,
    # which would make callback names unusable as dict keys or set members.
    # Equality is still plain string equality, so the string hash is valid.
    __hash__ = str.__hash__
152
153
154
class Predicate(object):
    """A callable that tests a progress counter of the log.

    Parameters
    ----------
    condition : str
        The trigger name. A name ending in ``'epochs'`` selects the
        ``'epochs_done'`` counter, anything else ``'iterations_done'``.
        A name starting with ``'every'`` makes the test periodic.
    num : int
        The period (for ``'every...'`` conditions) or the exact counter
        value to match (for all other conditions).

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        """Return whether the condition holds for ``log.status``."""
        if self.condition.endswith('epochs'):
            entry = log.status['epochs_done']
        else:
            entry = log.status['iterations_done']
        if self.condition.startswith('every'):
            # Periodic trigger: fires whenever the counter is a multiple.
            return entry % self.num == 0
        else:
            # One-shot trigger: fires only at the exact counter value.
            return entry == self.num
168
169
170
def has_done_epochs(log):
    """Return ``True`` while no epoch has been completed yet.

    Despite its name, the check is ``log.status['epochs_done'] == 0``, so
    the result is ``True`` only before the first epoch is finished.

    """
    return log.status['epochs_done'] == 0
172
173
174
def always_true(log):
    """Predicate that holds unconditionally; ``log`` is ignored."""
    return True
176
177
178
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    All logic of simple extensions is concentrated in the method
    :meth:`do`.  This method is called when certain conditions are
    fulfilled. The user can manage the conditions by calling the
    `add_condition` method and by passing arguments to the constructor.  In
    addition to specifying when :meth:`do` is called, it is possible to
    specify additional arguments passed to :meth:`do` under different
    conditions.

    Keyword arguments that are not trigger names are forwarded to the
    parent constructor.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    before_batch : bool
        If ``True``, :meth:`do` is invoked before every batch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch: bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "before_batch",
                                  "on_resumption", "on_interrupt",
                                  "after_epoch", "after_batch",
                                  "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        super_kwargs = {}
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        conditions = {}
        # Split the keyword arguments into triggers (handled here) and
        # everything else (handed over to the parent constructor).
        for key, value in kwargs.items():
            if key in trigger_keywords:
                conditions[key] = value
            else:
                super_kwargs[key] = value
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        Any previously registered conditions are discarded first.
        Triggers given with a falsy value are skipped.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters.

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        KeyError
            If a keyword with a truthy value is not a known trigger.

        """
        # Clear in place so that the list object itself stays the same.
        self._conditions[:] = []
        predicates = {'before_first_epoch': has_done_epochs}
        # Triggers missing from this mapping are callback names themselves.
        conditions = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            if value:
                if key in self.BOOLEAN_TRIGGERS:
                    self.add_condition([conditions.get(key, key)],
                                       predicate=predicates.get(key, None))
                elif key in self.INTEGER_TRIGGERS:
                    predicate = Predicate(key, value)
                    self.add_condition([conditions.get(key, key)],
                                       predicate=predicate)
                else:
                    raise KeyError("Invalid condition: {}".format(key))
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            called.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when it should not. If
            ``None``, an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is neither a list nor a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        for _callback_name in callbacks_names:
            # Re-evaluated per name: when no arguments were supplied each
            # condition receives its own fresh empty list.
            if not arguments:
                arguments = []
            self._conditions.append((_callback_name,
                                     predicate or always_true,
                                     arguments))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Also adds additional arguments if specified for a condition.
        :meth:`do` is called once for every registered condition whose
        callback name matches and whose predicate holds.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for callback_name, predicate, arguments in self._conditions:
            # The predicate (and the main loop lookup) is only evaluated
            # for conditions registered on the invoked callback.
            if (callback_name == callback_invoked and
                    predicate(self.main_loop.log)):
                self.do(callback_invoked, *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        This method is the right tool to use.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
        from_user : tuple

        """
        args = tuple(args)
        # Only the batch callbacks receive an argument (the batch) from the
        # main loop; it always comes first.
        if (which_callback == 'after_batch' or
                which_callback == 'before_batch'):
            return (args[0],), args[1:]
        return (), args
374
375
376
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        An iterable collection of sub-extensions to manage.
    run_before_children : bool, optional
        Whether the container extension's own logic should
        be dispatched before that of the sub-extensions.
        If ``False``, the containing extension is dispatched last.
        Defaults to ``True``.

    Notes
    -----
    The main use case for this class is bundling together groups
    of extensions that are most commonly used in tandem, configured
    so as to interact with one another. Encapsulating this pattern
    in a single extension reduces boilerplate.

    Sub-extensions are dispatched in the order specified in
    ``sub_extensions``, on whatever triggers they are individually
    configured to respect.

    Sub-extensions may be run on different triggers than the containing
    extension; the trigger keywords passed to the constructor
    for this class only affect the outer extension's logic, and
    sub-extensions should be configured independently (possibly in
    a constructor for a subclass of :class:`CompositeExtension`).

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        """Dispatch the callback to this extension and all sub-extensions.

        The container's own dispatch runs either before or after the
        sub-extensions, depending on ``run_before_children``.

        """
        def run_super():
            super(CompositeExtension, self).dispatch(callback_invoked,
                                                     *from_main_loop)
        if self.run_before_children:
            run_super()

        for ext in self.sub_extensions:
            ext.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            run_super()

    @property
    def main_loop(self):
        """The main loop of the container, as defined by the parent class."""
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        # Propagate the assignment so every sub-extension shares the
        # container's main loop.
        self._main_loop = value
        for sub in self.sub_extensions:
            sub.main_loop = value

    def do(self, which_callback, *args):
        """Do nothing; subclasses may override to add container logic."""
        pass
437
438
439
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    The request is made by setting the ``'training_finish_requested'``
    record of the current log row to ``True``.

    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        self.main_loop.log.current_row['training_finish_requested'] = True
446
447
448
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    Unless overridden through keyword arguments, the extension is
    triggered before the first epoch, on resumption, after every epoch,
    after training and on interrupt.

    """
    def __init__(self, **kwargs):
        kwargs.setdefault("before_first_epoch", True)
        kwargs.setdefault("on_resumption", True)
        kwargs.setdefault("after_training", True)
        kwargs.setdefault("after_epoch", True)
        kwargs.setdefault("on_interrupt", True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the items of a mapping, sorted by key.

        Items whose key starts with an underscore are skipped.

        """
        for attr, value in sorted(attribute_tuples.items(), key=first):
            if not attr.startswith("_"):
                print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        """Print a banner for the callback followed by the log contents.

        The training status and the current log row are printed for
        every callback except ``on_interrupt``.

        """
        log = self.main_loop.log
        print_status = True

        print()
        print(79 * "-")
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            print("BEFORE FIRST EPOCH")
        elif which_callback == "on_resumption":
            print("TRAINING HAS BEEN RESUMED")
        elif which_callback == "after_training":
            print("TRAINING HAS BEEN FINISHED:")
        elif which_callback == "after_epoch":
            print("AFTER ANOTHER EPOCH")
        elif which_callback == "on_interrupt":
            print("TRAINING HAS BEEN INTERRUPTED")
            print_status = False
        print(79 * "-")
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
488
489
490
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch
    by querying the `num_batches`, `num_examples` and `batch_size`
    attributes from the :class:`IterationScheme`. When this information is
    not available it will display a simplified progress bar that does not
    include the estimated time until the end of this epoch.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extension will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        del state['bar']
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # The bar was dropped when pickling; it is recreated lazily in
        # ``before_batch``.
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        The ``num_batches`` attribute of the iteration scheme if present,
        otherwise ``num_examples // batch_size`` if both attributes are
        present, otherwise ``None``.

        """
        iter_scheme = self.main_loop.data_stream.iteration_scheme
        if hasattr(iter_scheme, 'num_batches'):
            return iter_scheme.num_batches
        elif (hasattr(iter_scheme, 'num_examples') and
                hasattr(iter_scheme, 'batch_size')):
            return iter_scheme.num_examples // iter_scheme.batch_size
        return None

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            # Epoch length unknown: no percentage and no ETA widgets.
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        """Reset the per-epoch iteration counter."""
        self.iter_count = 0

    def after_epoch(self):
        """Finish and discard the current bar, if there is one."""
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        """Advance the bar by one step, creating it on first use."""
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
577
578
579
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    Created with callbacks like ``every_n_batches`` this extension
    averages the time.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g.  by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)

        def init_dict():
            # One independent ``{'train': 0, 'read_data': 0}`` per level.
            return {
                level: {'train': 0, 'read_data': 0}
                for level in ['batch', 'epoch']}
        # Latest and previous profile readings, and the values of the log
        # counters at which those readings were taken.
        self.current = init_dict()
        self.previous = init_dict()
        self.current_index = init_dict()
        self.previous_index = init_dict()
        self.prefix = prefix
        # A non-empty prefix is separated from the record name by "_".
        if self.prefix:
            self.prefix += '_'

    def do(self, which_callback, *args):
        """Write timing records into the current log row.

        On ``before_epoch`` only ``'time_initialization'`` is recorded.
        On ``after_batch`` and ``after_epoch`` the records
        ``'time_<action>_this_<level>'`` and ``'time_<action>_total'``
        (with the configured prefix) are written for the actions
        ``'train'`` and ``'read_data'``.

        Raises
        ------
        ValueError
            If invoked from any other callback.

        """
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level = 'batch'
            counter = 'iterations_done'
        elif which_callback == 'after_epoch':
            level = 'epoch'
            counter = 'epochs_done'
        else:
            raise ValueError('wrong callback type `{}`'.format(which_callback))
        for action in ['train', 'read_data']:
            self.previous_index[level][action] = (
                self.current_index[level][action])
            self.current_index[level][action] = (
                self.main_loop.log.status[counter])
            current_index = self.current_index[level][action]
            previous_index = self.previous_index[level][action]
            if current_index == previous_index:
                logger.debug('Timing extension was called twice this %s, '
                             'log was not updated.', level)
                # Nothing to report for this level
                continue

            self.previous[level][action] = self.current[level][action]
            # The same profile entry is read for both levels.
            self.current[level][action] = profile['training', 'epoch', action]

            # Time elapsed since the previous report, averaged over the
            # number of batches/epochs done in between.
            this_time = self.prefix + 'time_{}_this_{}'
            current_row[this_time.format(action, level)] = (
                (self.current[level][action] - self.previous[level][action]) /
                (current_index - previous_index))
            total_time = self.prefix + 'time_{}_total'
            current_row[total_time.format(action)] = \
                self.current[level][action]
659
660
661
class Timestamp(SimpleExtension):
    """Adds a human readable (ISO 8601) timestamp to the log.

    Unless overridden through keyword arguments, the extension is
    triggered after every epoch.

    Parameters
    ----------
    log_record : str, optional
        The record name to use. Defaults to 'timestamp'.
    separator : str, optional
        Separator between the date and time. ISO 8601 specifies 'T'.
        Here, we default to ' ' (blank space) for human readability.

    """
    DEFAULT_LOG_RECORD = 'timestamp'

    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
                 **kwargs):
        self.log_record = log_record
        self.separator = separator
        kwargs.setdefault('after_epoch', True)
        super(Timestamp, self).__init__(**kwargs)

    def do(self, *args):
        """Store the current timestamp under ``log_record`` in the log."""
        self.main_loop.log.current_row[self.log_record] = self.get_timestamp()

    def get_timestamp(self):
        """Return the current local time as an ISO 8601 string.

        ``separator`` is placed between the date and the time part.

        """
        # Separated into a method to override for ease of testing.
        # ``isoformat`` must be called on a datetime instance; calling it
        # on the class with the separator as its only argument raises
        # ``TypeError``.
        return datetime.datetime.now().isoformat(self.separator)
688