Completed
Pull Request — master (#1110)
by
unknown
04:45
created

Timing.__init__()   B

Complexity

Conditions 6

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
dl 0
loc 23
rs 7.6949
c 0
b 0
f 0
1
from __future__ import print_function
2
3
import logging
4
from abc import ABCMeta, abstractmethod
5
6
import progressbar
7
from six import add_metaclass
8
from toolz import first
9
10
# Logger named after this module; used for debug diagnostics below.
logger = logging.getLogger(name=__name__)
11
12
13
def callback(func):
    """Mark ``func`` as a training-extension callback and return it.

    The ``_is_callback`` marker set here is what :class:`CallbackName`
    inspects when deciding which names are valid callbacks.
    """
    setattr(func, '_is_callback', True)
    return func
16
17
18
class TrainingExtension(object):
    """The base class for training extensions.

    An extension is a set of callbacks sharing a joint context that are
    invoked at certain stages of the training procedure. These callbacks
    typically add a certain functionality to the training procedure,
    e.g. running validation on auxiliary datasets or early stopping.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names make it possible to tell apart
        several extensions of the same type that belong to the same main
        loop. When no name (or an empty one) is given, the name of the
        class is used.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs. Reading it before
        it has been assigned raises :class:`ValueError`.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        self.name = name if name else self.__class__.__name__

    @property
    def main_loop(self):
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Runs the callback with the given name.

        Descendants of :class:`TrainingExtension` may override this to
        intercept callback invocations, e.g. to skip them when some
        condition does not hold. This default implementation looks the
        callback up by name and calls it with ``args``.

        """
        method = getattr(self, str(callback_name))
        method(*args)

    @callback
    def on_resumption(self):
        """The callback invoked after training is resumed."""

    @callback
    def on_error(self):
        """The callback invoked when an error occurs."""

    @callback
    def before_training(self):
        """The callback invoked before training is started."""

    @callback
    def before_epoch(self):
        """The callback invoked before starting an epoch."""

    @callback
    def before_batch(self, batch):
        """The callback invoked before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """

    @callback
    def after_batch(self, batch):
        """The callback invoked after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """

    @callback
    def after_epoch(self):
        """The callback invoked after an epoch is finished."""

    @callback
    def after_training(self):
        """The callback invoked after training is finished."""

    @callback
    def on_interrupt(self):
        """The callback invoked when training is interrupted."""
127
128
129
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Comparing a :class:`CallbackName` with ``==`` or ``!=`` against a
    value that is not the name of a :class:`TrainingExtension` callback
    raises an error. This catches typos in checks such as
    ``which_callback == 'after_epoch'``.

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    # Defining ``__eq__`` without ``__hash__`` implicitly sets ``__hash__``
    # to ``None`` on Python 3, which made callback names unusable as dict
    # keys or set members. ``__eq__`` only ever answers ``True`` when the
    # plain string values are equal, so ``str.__hash__`` stays consistent.
    __hash__ = str.__hash__

    def __eq__(self, other):
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other

    def __ne__(self, other):
        # Without this, ``!=`` resolves to ``str.__ne__`` (on Python 2
        # always, on Python 3 because ``str`` defines its own ``__ne__``),
        # silently skipping the validation documented above.
        return not self.__eq__(other)
145
146
147
class Predicate(object):
    """Decide whether an integer-valued trigger fires for a given log.

    Parameters
    ----------
    condition : str
        An integer trigger name such as ``'every_n_epochs'`` or
        ``'after_n_batches'``. Names ending in ``'epochs'`` are checked
        against ``log.status['epochs_done']``; all others against
        ``log.status['iterations_done']``. Names starting with ``'every'``
        fire whenever that counter is a multiple of `num`; all others fire
        only when the counter equals `num`.
    num : int
        The period, or the exact count, to compare the counter with.

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        counter_key = ('epochs_done' if self.condition.endswith('epochs')
                       else 'iterations_done')
        count = log.status[counter_key]
        periodic = self.condition.startswith('every')
        return count % self.num == 0 if periodic else count == self.num
161
162
163
def has_done_epochs(log):
    """Return ``True`` while no epoch has been completed yet.

    Note that, despite its name, this is ``True`` only when
    ``log.status['epochs_done']`` is zero. It is the predicate attached to
    the ``before_first_epoch`` trigger.
    """
    epochs = log.status['epochs_done']
    return epochs == 0
165
166
167
def always_true(log):
    """Predicate accepting every log; the default used by add_condition."""
    del log  # unused: the parameter only matches the predicate signature
    return True
169
170
171
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    All logic of simple extensions is concentrated in the method
    :meth:`do`.  This method is called when certain conditions are
    fulfilled. The user can manage the conditions by calling the
    `add_condition` method and by passing arguments to the constructor.  In
    addition to specifying when :meth:`do` is called, it is possible to
    specify additional arguments passed to :meth:`do` under different
    conditions.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch : bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    Any other keyword arguments are forwarded to
    :class:`TrainingExtension`.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "on_resumption",
                                  "on_interrupt", "after_epoch",
                                  "after_batch", "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        super_kwargs = {}
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        conditions = {}
        # Split trigger keywords (handled here) from the rest (handled by
        # TrainingExtension, e.g. ``name``).
        for key, value in kwargs.items():
            if key in trigger_keywords:
                conditions[key] = value
            else:
                super_kwargs[key] = value
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        All previously registered conditions, including those added with
        :meth:`add_condition`, are discarded first.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters. Triggers given a falsy value are ignored.

        Returns
        -------
            The extension object (allows chaining calls).

        Raises
        ------
        KeyError
            If a truthy value is given for an unknown trigger name.

        """
        # Clear in place so the list object itself is kept.
        self._conditions[:] = []
        # Boolean triggers that need a predicate other than "always".
        predicates = {'before_first_epoch': has_done_epochs}
        # Trigger name -> main loop callback it attaches to. Triggers not
        # listed here are themselves callback names.
        conditions = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            if not value:
                continue
            if key in self.BOOLEAN_TRIGGERS:
                self.add_condition([conditions.get(key, key)],
                                   predicate=predicates.get(key, None))
            elif key in self.INTEGER_TRIGGERS:
                predicate = Predicate(key, value)
                self.add_condition([conditions.get(key, key)],
                                   predicate=predicate)
            else:
                raise KeyError("Invalid condition: {}".format(key))
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list or tuple of str
            The names of the callbacks in which :meth:`do` should be
            called; one condition is registered per name.
        predicate : function, optional
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when :meth:`do`
            should be called and ``False`` when it should not. If ``None``
            (or otherwise falsy), an always ``True`` predicate is used.
        arguments : iterable, optional
            Additional arguments to be passed to :meth:`do`. They will
            be appended to the ones passed from the main loop
            (e.g. the batch in case of the `after_batch` callback).
            The iterable is consumed once, here, so one-shot iterables
            such as generators are safe to pass.

        Returns
        -------
            The extension object (allows chaining calls).

        Raises
        ------
        ValueError
            If `callbacks_names` is not a list or a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        predicate = predicate or always_true
        # Snapshot the extra arguments. Storing the raw iterable let a
        # generator be exhausted by the first dispatch (and shared between
        # all callbacks registered here), and let later mutation by the
        # caller leak into the registered condition.
        arguments = tuple(arguments) if arguments else ()
        for callback_name in callbacks_names:
            self._conditions.append((callback_name, predicate, arguments))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Also adds additional arguments if specified for a condition.
        :meth:`do` is called once per matching condition, in registration
        order; the predicate is only evaluated when the callback matches.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for callback_name, predicate, arguments in self._conditions:
            if (callback_name == callback_invoked and
                    predicate(self.main_loop.log)):
                self.do(callback_invoked, *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        This method is the right tool to use.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
            The batch for ``after_batch``/``before_batch``, else empty.
        from_user : tuple
            The remaining arguments.

        """
        args = tuple(args)
        if (which_callback == 'after_batch' or
                which_callback == 'before_batch'):
            return (args[0],), args[1:]
        return (), args
366
367
368
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    The trigger keyword arguments of :class:`SimpleExtension` (for
    example ``after_n_epochs``) decide when the finish is requested.
    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        row = self.main_loop.log.current_row
        # Only a flag is set here; stopping is presumably done by the main
        # loop when it reads this entry (not visible in this module).
        row['training_finish_requested'] = True
375
376
377
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    By default a report is printed before the first epoch, on resumption,
    after every epoch, after training and on interrupt. Each of these
    defaults can be overridden through the constructor's keyword
    arguments.
    """
    def __init__(self, **kwargs):
        defaults = (("before_first_epoch", True),
                    ("on_resumption", True),
                    ("after_training", True),
                    ("after_epoch", True),
                    ("on_interrupt", True))
        for trigger, enabled in defaults:
            kwargs.setdefault(trigger, enabled)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the public (non-underscore) entries of a mapping by key."""
        ordered = sorted(attribute_tuples.items(), key=first)
        for name, value in ordered:
            if name.startswith("_"):
                continue
            print("\t", "{}:".format(name), value)

    def do(self, which_callback, *args):
        log = self.main_loop.log
        rule = 79 * "-"
        show_status = True

        print()
        print(rule)
        # Plain ``==`` comparisons are kept on purpose: ``which_callback``
        # may be a CallbackName, which validates names on comparison.
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            print("BEFORE FIRST EPOCH")
        elif which_callback == "on_resumption":
            print("TRAINING HAS BEEN RESUMED")
        elif which_callback == "after_training":
            print("TRAINING HAS BEEN FINISHED:")
        elif which_callback == "after_epoch":
            print("AFTER ANOTHER EPOCH")
        elif which_callback == "on_interrupt":
            print("TRAINING HAS BEEN INTERRUPTED")
            show_status = False
        print(rule)
        if show_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
417
418
419
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch
    by querying the `num_batches`, `num_examples` and `batch_size`
    attributes from the :class:`IterationScheme`. When this information is
    not available it will display a simplified progress bar that does not
    include the estimated time until the end of this epoch.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extensions will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        state.pop('bar', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        int or None
            The number of iterations, or ``None`` when the data stream
            exposes no usable iteration scheme information.

        """
        # A stream without an iteration scheme falls back to the
        # unknown-length bar instead of raising AttributeError.
        iter_scheme = getattr(self.main_loop.data_stream,
                              'iteration_scheme', None)
        if hasattr(iter_scheme, 'num_batches'):
            return iter_scheme.num_batches
        elif (hasattr(iter_scheme, 'num_examples') and
                hasattr(iter_scheme, 'batch_size')):
            num_examples = iter_scheme.num_examples
            batch_size = iter_scheme.batch_size
            # Round up: presumably a trailing partial batch is still
            # yielded as one more iteration. Flooring made the final
            # ``update`` exceed the bar's ``max_value``. Rounding up can
            # at worst leave the bar short of 100% before ``finish()``.
            return (num_examples + batch_size - 1) // batch_size
        return None

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        self.iter_count = 0

    def after_epoch(self):
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        # The bar is created lazily on the first batch of each epoch.
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
506
507
508
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.
        A non-empty prefix is joined to the record name with ``'_'``.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    Created with callbacks like ``every_n_batches`` this extension
    averages the time.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g. by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)
        levels = ('batch', 'epoch')
        # Cumulative profiler totals seen at the latest and the preceding
        # report, per level and per action.
        self.current = {level: {'train': 0, 'read_data': 0}
                        for level in levels}
        self.previous = {level: {'train': 0, 'read_data': 0}
                         for level in levels}
        # Counter values (iterations_done / epochs_done) at those reports.
        self.current_index = dict.fromkeys(levels, 0)
        self.previous_index = dict.fromkeys(levels, 0)
        self.prefix = prefix + '_' if prefix else prefix

    def do(self, which_callback, *args):
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            # NOTE(review): this record ignores ``prefix``; kept as-is for
            # compatibility with consumers reading 'time_initialization'.
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level = 'batch'
            counter = 'iterations_done'
        elif which_callback == 'after_epoch':
            level = 'epoch'
            counter = 'epochs_done'
        else:
            raise ValueError("Timing extension does not support the "
                             "'{}' callback".format(which_callback))

        # Advance the counter bookkeeping exactly once per call. Doing it
        # inside the per-action loop made the second action ('read_data')
        # always see an unchanged index, so it was never reported.
        self.previous_index[level] = self.current_index[level]
        self.current_index[level] = self.main_loop.log.status[counter]
        steps = self.current_index[level] - self.previous_index[level]
        if steps == 0:
            logger.debug('Timing extension was called twice this %s, '
                         'log was not updated.', level)
            # Nothing to report for this level
            return

        this_time = self.prefix + 'time_{}_this_{}'
        total_time = self.prefix + 'time_{}_total'
        for action in ['train', 'read_data']:
            self.previous[level][action] = self.current[level][action]
            self.current[level][action] = profile['training', 'epoch', action]

            # Average time per batch/epoch since the previous report.
            current_row[this_time.format(action, level)] = (
                (self.current[level][action] - self.previous[level][action]) /
                steps)
            current_row[total_time.format(action)] = \
                self.current[level][action]
589