Completed
Pull Request — master (#1109)
by David
04:53
created

EarlyStopping.__init__()   B

Complexity

Conditions 6

Size

Total Lines 23

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 23
rs 7.6949
1
import logging
2
3
from . import FinishAfter, CompositeExtension
4
from .training import TrackTheBest
5
from .predicates import OnLogRecord
6
7
8
# Module-level logger, named after this module; the extensions defined in
# this file use it for their debug/info/warning messages.
logger = logging.getLogger(__name__)
9
10
11
class FinishIfNoImprovementAfter(FinishAfter):
    """Stop after improvements have ceased for a given period.

    Parameters
    ----------
    notification_name : str
        The name of the log record to look for which indicates a new
        best performer has been found.  Note that the value of this
        record is not inspected.
    iterations : int, optional
        The number of iterations to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    epochs : int, optional
        The number of epochs to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    patience_log_record : str, optional
        The name under which to record the number of iterations (or
        epochs, when `epochs` is used) we are currently willing to wait
        for a new best performer. Defaults to
        `notification_name + '_patience_epochs'` or
        `notification_name + '_patience_iterations'`, depending
        which measure is being used.

    Raises
    ------
    ValueError
        If both or neither of `iterations` and `epochs` are specified.

    Notes
    -----
    By default, runs after each epoch. This can be manipulated via
    keyword arguments (see :class:`blocks.extensions.SimpleExtension`).

    The remaining patience is recomputed from the log status every time
    the extension runs, and the parent class' ``do`` is invoked as soon
    as it is zero *or negative*. The latter matters whenever the
    extension does not run at every step of the chosen measure (e.g.
    `iterations` combined with the default ``after_epoch`` trigger), in
    which case the patience can jump from a positive value straight to
    a negative one.

    """
    def __init__(self, notification_name, iterations=None, epochs=None,
                 patience_log_record=None, **kwargs):
        # Exactly one of the two measures must be given: the comparison is
        # True both when neither and when both are specified.
        if (epochs is None) == (iterations is None):
            raise ValueError("Need exactly one of epochs or iterations "
                             "to be specified")
        self.notification_name = notification_name
        self.iterations = iterations
        self.epochs = epochs
        kwargs.setdefault('after_epoch', True)
        # Iteration/epoch counters at which the last best was observed;
        # None until the first notification has been seen.
        self.last_best_iter = self.last_best_epoch = None
        if patience_log_record is None:
            self.patience_log_record = (notification_name + '_patience' +
                                        ('_epochs' if self.epochs is not None
                                         else '_iterations'))
        else:
            self.patience_log_record = patience_log_record
        super(FinishIfNoImprovementAfter, self).__init__(**kwargs)

    def update_best(self):
        """Remember the current position if a new best was just logged.

        Only the presence of `notification_name` in the current log row
        is checked; its value is ignored.

        """
        # Here mainly so we can easily subclass different criteria.
        if self.notification_name in self.main_loop.log.current_row:
            self.last_best_iter = self.main_loop.log.status['iterations_done']
            self.last_best_epoch = self.main_loop.log.status['epochs_done']

    def do(self, which_callback, *args):
        """Record the remaining patience and stop once it is exhausted.

        Parameters
        ----------
        which_callback : str
            The name of the callback that triggered this extension;
            forwarded unchanged to the parent class.
        \*args
            Additional callback arguments, forwarded to the parent class.

        """
        self.update_best()
        # If we haven't encountered a best yet, then we should just bail.
        if self.last_best_iter is None:
            return
        if self.epochs is not None:
            since = (self.main_loop.log.status['epochs_done'] -
                     self.last_best_epoch)
            patience = self.epochs - since
        else:
            since = (self.main_loop.log.status['iterations_done'] -
                     self.last_best_iter)
            patience = self.iterations - since
        logger.debug('%s: Writing patience of %d to current log record (%s) '
                     'at iteration %d', self.__class__.__name__, patience,
                     self.patience_log_record,
                     self.main_loop.log.status['iterations_done'])
        self.main_loop.log.current_row[self.patience_log_record] = patience
        # Use <= rather than ==: if this extension is not triggered at
        # every step of the chosen measure, the patience may skip over
        # zero, and an equality test would then never stop training.
        if patience <= 0:
            super(FinishIfNoImprovementAfter, self).do(which_callback,
                                                       *args)
84
85
86
class EarlyStopping(CompositeExtension):
    """A 'batteries-included' early stopping extension.

    Combines a :class:`TrackTheBest` extension (which writes a
    notification to the log when a new best value is seen) with a
    :class:`FinishIfNoImprovementAfter` extension (which stops training
    when no such notification has appeared for a while), and optionally
    configures a checkpoint extension to save on every new best.

    Parameters
    ----------
    record_name : str
        The log record entry whose value represents the quantity to base
        early stopping decisions on, e.g. some measure of validation set
        performance.
    checkpoint_extension : :class:`~blocks.extensions.Checkpoint`, optional
        A :class:`~blocks.extensions.Checkpoint` instance to configure to
        save a checkpoint when a new best performer is found.
    checkpoint_filename : str, optional
        The filename to use for the 'current best' checkpoint. Must be
        provided if ``checkpoint_extension`` is specified.
    notification_name : str, optional
        The name to be written in the log when a new best-performing
        model is found. Defaults to ``record_name + '_best_so_far'``.
    choose_best : callable, optional
        See :class:`TrackTheBest`.
    iterations : int, optional
        See :class:`FinishIfNoImprovementAfter`.
    epochs : int, optional
        See :class:`FinishIfNoImprovementAfter`.

    Raises
    ------
    ValueError
        If ``checkpoint_extension`` is given while ``checkpoint_filename``
        is `None`.

    Notes
    -----
    Trigger keyword arguments will affect how often the log is inspected
    for the record name (in order to determine if a new best has been
    found), as well as how often a decision is made about whether to
    continue training. By default, ``after_epoch`` is set,
    as is ``before_training``, where some sanity checks are performed
    (including the optional self-management of checkpointing).

    If ``checkpoint_extension`` is not in the main loop's extensions list
    when the `before_training` trigger is run, it will be added as a
    sub-extension of this object.

    """

    def __init__(self, record_name, checkpoint_extension=None,
                 checkpoint_filename=None, notification_name=None,
                 choose_best=min, iterations=None, epochs=None, **kwargs):
        if notification_name is None:
            notification_name = record_name + '_best_so_far'
        kwargs.setdefault('after_epoch', True)
        # Both sub-extensions receive the same trigger keyword arguments.
        tracking_ext = TrackTheBest(record_name, notification_name,
                                    choose_best=choose_best, **kwargs)
        stopping_ext = FinishIfNoImprovementAfter(notification_name,
                                                  iterations=iterations,
                                                  epochs=epochs,
                                                  **kwargs)
        self.checkpoint_extension = checkpoint_extension
        if checkpoint_extension and checkpoint_filename:
            # Ask the checkpoint extension to save to `checkpoint_filename`
            # whenever the 'new best' notification is present in the log.
            checkpoint_extension.add_condition(['after_batch'],
                                               OnLogRecord(notification_name),
                                               (checkpoint_filename,))
        elif checkpoint_extension is not None and checkpoint_filename is None:
            raise ValueError('checkpoint_extension specified without '
                             'checkpoint_filename')
        # Set only after the sub-extensions were built, so that only this
        # composite extension itself gets the `before_training` trigger
        # (unless the caller passed it explicitly).
        kwargs.setdefault('before_training', True)
        super(EarlyStopping, self).__init__([tracking_ext, stopping_ext],
                                            **kwargs)

    def do(self, which_callback, *args):
        """Perform the `before_training` checks on the checkpoint extension.

        If a checkpoint extension was configured but is not part of the
        main loop's extensions, it is adopted as a sub-extension of this
        object. If it is part of the main loop but placed before this
        extension, a warning is logged. For any other callback, or when
        no checkpoint extension was configured, nothing is done here.

        Parameters
        ----------
        which_callback : str
            The name of the callback that triggered this extension.
        \*args
            Additional callback arguments (unused).

        """
        checkpoint = self.checkpoint_extension
        if which_callback != 'before_training' or not checkpoint:
            return
        exts = self.main_loop.extensions
        if checkpoint not in exts:
            # `before_training` may fire more than once for the same
            # object (e.g. when training is resumed); never register the
            # checkpoint extension as a sub-extension twice.
            if checkpoint not in self.sub_extensions:
                logger.info('%s: checkpoint extension %s not in main loop '
                            'extensions, adding as sub-extension of %s',
                            self.__class__.__name__, checkpoint, self)
                checkpoint.main_loop = self.main_loop
                self.sub_extensions.append(checkpoint)
        elif exts.index(checkpoint) < exts.index(self):
            # A smaller index means the checkpoint extension sits earlier
            # in the list than this extension, i.e. ahead of the
            # extension that writes the 'new best' notification.
            logger.warning('%s: configured checkpointing extension '
                           'appears before %s in main loop '
                           'extensions list. This may lead to '
                           'unwanted results, as the notification '
                           'that would trigger serialization '
                           'of a new best will not have been '
                           'written yet when the checkpointing '
                           'extension is run.',
                           self.__class__.__name__, self)
170