Completed
Pull Request — master (#1136)
by
unknown
06:56 queued 02:01
created

Timing.__init__()   A

Complexity

Conditions 4

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 4
c 2
b 0
f 0
dl 0
loc 16
rs 9.2

1 Method

Rating   Name   Duplication   Size   Complexity  
A Timing.init_dict() 0 4 2
1
from __future__ import print_function
2
3
import logging
4
from abc import ABCMeta, abstractmethod
5
6
import progressbar
7
from six import add_metaclass
8
from toolz import first
9
10
logger = logging.getLogger(__name__)
11
12
13
def callback(func):
    """Mark ``func`` as a training-extension callback.

    The ``_is_callback`` attribute of ``func`` is set to ``True`` and the
    very same object is returned, which makes this usable as a decorator.

    """
    setattr(func, '_is_callback', True)
    return func
16
17
18
class TrainingExtension(object):
    """Common base of all training extensions.

    An extension bundles callbacks that share one context and that are
    invoked at fixed stages of the training procedure. Typical uses are
    adding functionality to training, such as validation on auxiliary
    datasets or early stopping.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names make it possible to tell apart
        several extensions of the same type that belong to the same main
        loop. When omitted (or falsy), the class name is used.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # Any falsy name (``None``, ``""``) falls back to the class name.
        self.name = name or self.__class__.__name__

    @property
    def main_loop(self):
        # Reading the main loop before it was assigned is an error.
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Invoke the callback called `callback_name` with `args`.

        This indirection exists so that descendants of
        :class:`TrainingExtension` can intercept callback invocations,
        e.g. to block them while some condition does not hold. Here the
        method with the given name is simply looked up and called.

        """
        handler = getattr(self, str(callback_name))
        handler(*args)

    @callback
    def on_resumption(self):
        """Called after training is resumed."""

    @callback
    def on_error(self, exception):
        """Called when an error occurs.

        Parameters
        ----------
        exception : object
            Exception occurred during the main loop run.

        """

    @callback
    def before_training(self):
        """Called before training is started."""

    @callback
    def before_epoch(self):
        """Called before an epoch starts."""

    @callback
    def before_batch(self, batch):
        """Called before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """

    @callback
    def after_batch(self, batch):
        """Called after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """

    @callback
    def after_epoch(self):
        """Called after an epoch is finished."""

    @callback
    def after_training(self):
        """Called after training is finished."""

    @callback
    def on_interrupt(self):
        """Called when training is interrupted."""
133
134
135
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Instances behave like ordinary strings, except that ``==`` validates
    its right-hand operand against the callbacks declared on
    :class:`TrainingExtension`.

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    # Overriding ``__eq__`` makes Python 3 set ``__hash__`` to ``None``,
    # which would make callback names unusable as dict keys or set
    # members. Equality is still plain string equality, so the string
    # hash remains consistent with it.
    __hash__ = str.__hash__

    def __eq__(self, other):
        # Collected on every call so that the list always reflects the
        # methods currently marked with ``_is_callback`` on the class.
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other
151
152
153
class Predicate(object):
    """Callable test of a training counter against a fixed number.

    Parameters
    ----------
    condition : str
        Name of the trigger. A name ending in ``'epochs'`` selects the
        ``'epochs_done'`` counter of the log status, any other name
        selects ``'iterations_done'``. A name starting with ``'every'``
        asks for a divisibility test, any other name for equality.
    num : int
        The number the selected counter is checked against.

    """
    def __init__(self, condition, num):
        self.num = num
        self.condition = condition

    def __call__(self, log):
        """Return whether the counter read from `log` satisfies the test."""
        counter = ('epochs_done' if self.condition.endswith('epochs')
                   else 'iterations_done')
        done = log.status[counter]
        if self.condition.startswith('every'):
            return done % self.num == 0
        return done == self.num
167
168
169
def has_done_epochs(log):
    """Return ``True`` while the log reports zero finished epochs.

    Despite its name, the result is ``True`` only when
    ``log.status['epochs_done']`` equals 0.

    """
    epochs_done = log.status['epochs_done']
    return epochs_done == 0
171
172
173
def always_true(log):
    """Predicate that ignores `log` and accepts unconditionally."""
    return True
175
176
177
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    All logic of simple extensions is concentrated in the method
    :meth:`do`.  This method is called when certain conditions are
    fulfilled. The user can manage the conditions by calling the
    :meth:`add_condition` method and by passing arguments to the
    constructor.  In addition to specifying when :meth:`do` is called, it
    is possible to specify additional arguments passed to :meth:`do` under
    different conditions.

    Keyword arguments that are not trigger names are forwarded to the
    :class:`TrainingExtension` constructor.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    before_batch : bool
        If ``True``, :meth:`do` is invoked before every batch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch: bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "before_batch",
                                  "on_resumption", "on_interrupt",
                                  "after_epoch", "after_batch",
                                  "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        super_kwargs = {}
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        conditions = {}
        # Split the keywords: triggers configure this class, everything
        # else is handed to the parent constructor.
        for key, value in kwargs.items():
            if key in trigger_keywords:
                conditions[key] = value
            else:
                super_kwargs[key] = value
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        Conditions registered earlier are discarded first.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters.

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        KeyError
            If a keyword with a truthy value is not a known trigger.

        """
        # Cleared in place so the existing list object is kept.
        self._conditions[:] = []
        predicates = {'before_first_epoch': has_done_epochs}
        # Triggers that are served by a differently named callback;
        # triggers missing here map to the callback of the same name.
        conditions = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            # Falsy values (``False``, ``None``, 0) disable the trigger.
            if not value:
                continue
            if key in self.BOOLEAN_TRIGGERS:
                self.add_condition([conditions.get(key, key)],
                                   predicate=predicates.get(key))
            elif key in self.INTEGER_TRIGGERS:
                self.add_condition([conditions.get(key, key)],
                                   predicate=Predicate(key, value))
            else:
                raise KeyError("Invalid condition: {}".format(key))
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            invoked.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when should not. If ``None``,
            an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
            The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is neither a list nor a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        # Snapshot the extra arguments once: a one-shot iterable (e.g. a
        # generator) would otherwise be exhausted by the first dispatch
        # and silently yield no arguments afterwards.
        arguments = tuple(arguments) if arguments else ()
        if not predicate:
            predicate = always_true
        for _callback_name in callbacks_names:
            self._conditions.append((_callback_name, predicate, arguments))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Also adds additional arguments if specified for a condition.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for callback_name, predicate, arguments in self._conditions:
            if (callback_name == callback_invoked and
                    predicate(self.main_loop.log)):
                self.do(callback_invoked, *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        This method is the right tool to use.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
        from_user : tuple

        """
        args = tuple(args)
        # Only the batch callbacks receive an argument (the batch) from
        # the main loop; it always comes first.
        if (which_callback == 'after_batch' or
                which_callback == 'before_batch'):
            return (args[0],), args[1:]
        return (), args
373
374
375
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        An iterable collection of sub-extensions to manage.
    run_before_children : bool, optional
        Whether the container extension's own logic should
        be dispatched before that of the sub-extensions.
        If ``False``, the containing extension is dispatched last.
        Defaults to ``True``.

    Notes
    -----
    The main use case for this class is bundling together groups
    of extensions that are most commonly used in tandem, configured
    so as to interact with one another. Encapsulating this pattern
    in a single extension reduces boilerplate.

    Sub-extensions are dispatched in the order specified in
    ``sub_extensions``, on whatever triggers they are individually
    configured to respect.

    Sub-extensions may be run on different triggers than the containing
    extension; the trigger keywords passed to the constructor
    for this class only affect the outer extension's logic, and
    sub-extensions should be configured independently (possibly in
    a constructor for a subclass of :class:`CompositeExtension`).

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        """Dispatch the callback to this extension and to every child.

        The container's own conditions are checked either before or
        after the children, depending on ``run_before_children``.

        """
        own_dispatch = super(CompositeExtension, self).dispatch

        if self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

        for child in self.sub_extensions:
            child.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

    @property
    def main_loop(self):
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        # The children share the main loop of their container.
        self._main_loop = value
        for child in self.sub_extensions:
            child.main_loop = value

    def do(self, which_callback, *args):
        """Do nothing; subclasses may add logic for the container."""
436
437
438
class FinishAfter(SimpleExtension):
    """Requests the end of the training process when triggered."""
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        # Leave a flag in the current log row.
        row = self.main_loop.log.current_row
        row['training_finish_requested'] = True
445
446
447
class Printing(SimpleExtension):
    """Writes log messages to the screen.

    Unless overridden through keyword arguments, printing happens before
    the first epoch, on resumption, after every epoch, after training
    and on interrupt.

    """
    def __init__(self, **kwargs):
        for trigger in ("before_first_epoch", "on_resumption",
                        "after_training", "after_epoch", "on_interrupt"):
            kwargs.setdefault(trigger, True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        # Entries are shown sorted by name; names with a leading
        # underscore are skipped.
        ordered = sorted(attribute_tuples.items(), key=first)
        for attr, value in ordered:
            if attr.startswith("_"):
                continue
            print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        log = self.main_loop.log
        rule = "-" * 79
        print_status = True
        header = None

        print()
        print(rule)
        # Plain ``==`` comparisons are kept on purpose: `which_callback`
        # may be a str subclass with its own equality.
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            header = "BEFORE FIRST EPOCH"
        elif which_callback == "on_resumption":
            header = "TRAINING HAS BEEN RESUMED"
        elif which_callback == "after_training":
            header = "TRAINING HAS BEEN FINISHED:"
        elif which_callback == "after_epoch":
            header = "AFTER ANOTHER EPOCH"
        elif which_callback == "on_interrupt":
            header = "TRAINING HAS BEEN INTERRUPTED"
            # On interrupt only the banner is shown.
            print_status = False
        if header is not None:
            print(header)
        print(rule)
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
487
488
489
class ProgressBar(TrainingExtension):
    """Show a progress bar while training.

    The number of iterations per epoch is inferred from the
    `num_batches`, `num_examples` and `batch_size` attributes of the
    :class:`IterationScheme`. If that information is missing, a simpler
    bar without the estimated time until the end of the epoch is shown.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extension will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # The live progress bar is left out of the pickled state
        # (it might contain unpicklable file handles).
        state = self.__dict__.copy()
        state.pop('bar')
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.bar = None

    def get_iter_per_epoch(self):
        """Infer the number of iterations per epoch, or return ``None``."""
        scheme = self.main_loop.data_stream.iteration_scheme
        if hasattr(scheme, 'num_batches'):
            return scheme.num_batches
        has_sizes = (hasattr(scheme, 'num_examples') and
                     hasattr(scheme, 'batch_size'))
        if has_sizes:
            # NOTE(review): floor division does not count a trailing
            # partial batch -- confirm this matches the schemes in use.
            return scheme.num_examples // scheme.batch_size
        return None

    def create_bar(self):
        """Build a fresh progress bar for the current epoch.

        The widget set depends on whether `self.get_iter_per_epoch()`
        could determine the epoch length.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']
        label = "Epoch {}, step ".format(epochs_done)

        if iter_per_epoch is not None:
            widgets = [label,
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]
        else:
            widgets = [label,
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        self.iter_count = 0

    def after_epoch(self):
        if self.bar is not None:
            self.bar.finish()
            self.bar = None

    def before_batch(self, batch):
        # The bar is created lazily on the first batch of an epoch.
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
576
577
578
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    Created with callbacks like ``every_n_batches`` this extension
    averages the time.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g.  by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)

        def init_dict():
            # One independent zeroed counter per (level, action) pair.
            return {
                level: {'train': 0, 'read_data': 0}
                for level in ['batch', 'epoch']}
        # Profile totals seen at the latest and the previous report.
        self.current = init_dict()
        self.previous = init_dict()
        # Values of the batch/epoch counter at those two reports.
        self.current_index = init_dict()
        self.previous_index = init_dict()
        self.prefix = prefix
        if self.prefix:
            self.prefix += '_'

    def do(self, which_callback, *args):
        """Write timing records into the current log row.

        Parameters
        ----------
        which_callback : str
            ``'before_epoch'`` records the initialization time;
            ``'after_batch'`` and ``'after_epoch'`` record the training
            and data reading time averaged over the batches or epochs
            elapsed since the previous report.

        Raises
        ------
        ValueError
            If called from any other callback.

        """
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level = 'batch'
            counter = 'iterations_done'
        elif which_callback == 'after_epoch':
            level = 'epoch'
            counter = 'epochs_done'
        else:
            raise ValueError('wrong callback type `{}`'.format(which_callback))
        for action in ['train', 'read_data']:
            self.previous_index[level][action] = (
                self.current_index[level][action])
            self.current_index[level][action] = (
                self.main_loop.log.status[counter])
            # Use the per-action counters here. Binding the whole
            # per-level dicts would make the subtraction below raise a
            # TypeError and the "called twice" check compare dicts.
            current_index = self.current_index[level][action]
            previous_index = self.previous_index[level][action]
            if current_index == previous_index:
                logger.debug('Timing extension was called twice this %s, '
                             'log was not updated.', level)
                # Nothing to report for this level
                continue

            self.previous[level][action] = self.current[level][action]
            self.current[level][action] = profile['training', 'epoch', action]

            # Average over the batches/epochs elapsed since last report.
            this_time = self.prefix + 'time_{}_this_{}'
            current_row[this_time.format(action, level)] = (
                (self.current[level][action] - self.previous[level][action]) /
                (current_index - previous_index))
            total_time = self.prefix + 'time_{}_total'
            current_row[total_time.format(action)] = \
                self.current[level][action]
658