Completed
Pull Request — master (#1109)
by David
04:54
created

EarlyStopping.do()   B

Complexity

Conditions 5

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 5
c 1
b 0
f 0
dl 0
loc 13
rs 8.5454
1
import logging
2
3
from . import FinishAfter, CompositeExtension
4
from .training import TrackTheBest
5
from .predicates import OnLogRecord
6
7
8
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
9
10
11
class FinishIfNoImprovementAfter(FinishAfter):
    """Stop after improvements have ceased for a given period.

    Parameters
    ----------
    notification_name : str
        The name of the log record to look for which indicates a new
        best performer has been found.  Note that the value of this
        record is not inspected.
    iterations : int, optional
        The number of iterations to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    epochs : int, optional
        The number of epochs to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    patience_log_record : str, optional
        The name under which to record the number of iterations (or
        epochs, if `epochs` was given) we are currently willing to wait
        for a new best performer.
        Defaults to `notification_name + '_patience_epochs'` or
        `notification_name + '_patience_iterations'`, depending
        which measure is being used.

    Raises
    ------
    ValueError
        If both or neither of `iterations` and `epochs` are given.

    Notes
    -----
    By default, runs after each epoch. This can be manipulated via
    keyword arguments (see :class:`blocks.extensions.SimpleExtension`).

    """
    def __init__(self, notification_name, iterations=None, epochs=None,
                 patience_log_record=None, **kwargs):
        # Exactly one of the two measures has to be provided.
        if (epochs is None) == (iterations is None):
            raise ValueError("Need exactly one of epochs or iterations "
                             "to be specified")
        self.notification_name = notification_name
        self.iterations = iterations
        self.epochs = epochs
        kwargs.setdefault('after_epoch', True)
        # Both stay None until the first "new best" notification is seen.
        self.last_best_iter = self.last_best_epoch = None
        if patience_log_record is None:
            self.patience_log_record = (notification_name + '_patience' +
                                        ('_epochs' if self.epochs is not None
                                         else '_iterations'))
        else:
            self.patience_log_record = patience_log_record
        super(FinishIfNoImprovementAfter, self).__init__(**kwargs)

    def update_best(self):
        """Remember the current position if a new best was just logged."""
        # Here mainly so we can easily subclass different criteria.
        if self.notification_name in self.main_loop.log.current_row:
            self.last_best_iter = self.main_loop.log.status['iterations_done']
            self.last_best_epoch = self.main_loop.log.status['epochs_done']

    def do(self, which_callback, *args):
        """Record the remaining patience and finish once it is exhausted.

        The remaining patience is written to the current log row under
        ``self.patience_log_record``. When it reaches zero (or drops
        below it), the parent class's ``do`` is invoked.

        """
        self.update_best()
        # If we haven't encountered a best yet, then we should just bail.
        if self.last_best_iter is None:
            return
        if self.epochs is not None:
            since = (self.main_loop.log.status['epochs_done'] -
                     self.last_best_epoch)
            patience = self.epochs - since
        else:
            since = (self.main_loop.log.status['iterations_done'] -
                     self.last_best_iter)
            patience = self.iterations - since
        logger.debug('%s: Writing patience of %d to current log record (%s) '
                     'at iteration %d', self.__class__.__name__, patience,
                     self.patience_log_record,
                     self.main_loop.log.status['iterations_done'])
        self.main_loop.log.current_row[self.patience_log_record] = patience
        # Use `<=` rather than `==`: when this extension is triggered less
        # often than the unit patience is measured in (e.g. patience in
        # iterations but checked after each epoch), the counter can skip
        # past zero, and an equality test would then never fire.
        if patience <= 0:
            super(FinishIfNoImprovementAfter, self).do(which_callback,
                                                       *args)
84
85
86
class EarlyStopping(CompositeExtension):
    """A 'batteries-included' early stopping extension.

    Parameters
    ----------
    record_name : str
        The log record entry whose value represents the quantity to base
        early stopping decisions on, e.g. some measure of validation set
        performance.
    checkpoint_extension : :class:`~blocks.extensions.Checkpoint`, optional
        A :class:`~blocks.extensions.Checkpoint` instance to configure to
        save a checkpoint when a new best performer is found.
    checkpoint_filename : str, optional
        The filename to use for the 'current best' checkpoint. Must be
        provided if ``checkpoint_extension`` is specified.
    notification_name : str, optional
        The name to be written in the log when a new best-performing
        model is found. Defaults to ``record_name + '_best_so_far'``.
    choose_best : callable, optional
        See :class:`TrackTheBest`.
    iterations : int, optional
        See :class:`FinishIfNoImprovementAfter`.
    epochs : int, optional
        See :class:`FinishIfNoImprovementAfter`.

    Raises
    ------
    ValueError
        If ``checkpoint_extension`` is given without
        ``checkpoint_filename``.

    Notes
    -----
    .. warning::
        If you want the best model to be saved, you need to specify
        a value for the ``checkpoint_extension`` and
        ``checkpoint_filename`` arguments!

    Trigger keyword arguments will affect how often the log is inspected
    for the record name (in order to determine if a new best has been
    found), as well as how often a decision is made about whether to
    continue training. By default, ``after_epoch`` is set,
    as is ``before_training``, where some sanity checks are performed
    (including the optional self-management of checkpointing).

    If ``checkpoint_extension`` is not in the main loop's extensions list
    when the `before_training` trigger is run, it will be added as a
    sub-extension of this object.

    Examples
    --------
    To simply track the best value of a log entry and halt training
    when it hasn't improved in a sufficient amount of time, we could
    use e.g.

    >>> stopping_ext = EarlyStopping('valid_error', iterations=100)

    which would halt training if a new minimum ``valid_error`` has not
    been achieved in 100 iterations (i.e. minibatches/steps). To measure
    in terms of epochs (which usually correspond to passes through the
    training set), you could use

    >>> epoch_stop_ext = EarlyStopping('valid_error', epochs=5)

    If you are tracking a log entry where there's a different definition
    of 'best', you can provide a callable that takes two log values
    and returns the one that :class:`EarlyStopping` should consider
    "better". For example, if you were tracking accuracy, where higher
    is better, you could pass the built-in ``max`` function:

    >>> max_acc_stop = EarlyStopping('valid_accuracy', choose_best=max,
    ...                              notification_name='highest_acc',
    ...                              epochs=10)

    Above we've also provided an alternate notification name, meaning
    a value of ``True`` will be written under the entry name
    ``highest_acc`` whenever a new highest accuracy is found (by default
    this would be a name like ``valid_accuracy_best_so_far``).

    Let's configure a checkpointing extension to save the model and log
    (but not the main loop):

    >>> checkpoint = Checkpoint('my_model.tar', save_main_loop=False,
    ...                         save_separately=['model', 'log'],
    ...                         after_epoch=True)

    When we pass this object to :class:`EarlyStopping`, along with a
    different filename, :class:`EarlyStopping` will configure that same
    checkpointing extension to *also* serialize to ``best_model.tar`` when
    a new best value of validation error is achieved.

    >>> stopping = EarlyStopping('valid_error', checkpoint,
    ...                          'best_model.tar', epochs=5)

    Finally, we'll set up the main loop:

    >>> from blocks.main_loop import MainLoop
    >>> # You would, of course, use a real algorithm and data stream here.
    >>> algorithm = data_stream = None
    >>> main_loop = MainLoop(algorithm=algorithm,
    ...                      data_stream=data_stream,
    ...                      extensions=[stopping, checkpoint])

    Note that you do want to place the checkpoint extension *after*
    the stopping extension, so that the appropriate log records
    have been written in order to trigger the checkpointing
    extension.

    It's also possible to in-line the creation of the
    checkpointing extension:

    >>> main_loop = MainLoop(algorithm=algorithm,
    ...                      data_stream=data_stream,
    ...                      extensions=[EarlyStopping(
    ...                          'valid_error',
    ...                          Checkpoint('my_model.tar',
    ...                                     save_main_loop=False,
    ...                                     save_separately=['model',
    ...                                                      'log'],
    ...                                     after_epoch=True),
    ...                          'my_best_model.tar',
    ...                          epochs=5)])

    Note that we haven't added the checkpointing extension to the
    main loop's extensions list. No problem: :class:`EarlyStopping` will
    detect that it isn't being managed by the main loop and manage it
    internally. It will automatically be executed in the right order
    for it to function properly alongside :class:`EarlyStopping`.

    """
    def __init__(self, record_name, checkpoint_extension=None,
                 checkpoint_filename=None, notification_name=None,
                 choose_best=min, iterations=None, epochs=None, **kwargs):
        if notification_name is None:
            notification_name = record_name + '_best_so_far'
        kwargs.setdefault('after_epoch', True)
        # Both sub-extensions share the same trigger keyword arguments.
        tracking_ext = TrackTheBest(record_name, notification_name,
                                    choose_best=choose_best, **kwargs)
        stopping_ext = FinishIfNoImprovementAfter(notification_name,
                                                  iterations=iterations,
                                                  epochs=epochs,
                                                  **kwargs)
        self.checkpoint_extension = checkpoint_extension
        if checkpoint_extension and checkpoint_filename:
            # Make the checkpoint extension also save to
            # `checkpoint_filename` whenever the notification record
            # shows up in the log.
            checkpoint_extension.add_condition(['after_batch'],
                                               OnLogRecord(notification_name),
                                               (checkpoint_filename,))
        elif checkpoint_extension is not None and checkpoint_filename is None:
            raise ValueError('checkpoint_extension specified without '
                             'checkpoint_filename')
        # `before_training` is only added for this composite itself (after
        # the sub-extensions were built) so that `do` can run its checks.
        kwargs.setdefault('before_training', True)
        super(EarlyStopping, self).__init__([tracking_ext, stopping_ext],
                                            **kwargs)

    def do(self, which_callback, *args):
        """Sanity-check the checkpoint extension before training starts.

        Only acts on the ``before_training`` callback, and only when a
        checkpoint extension was configured. If that extension is not in
        the main loop's extensions list, it is adopted as a sub-extension
        of this object. Otherwise a warning is logged if it is listed
        *before* this extension.

        """
        if which_callback == 'before_training' and self.checkpoint_extension:
            if self.checkpoint_extension not in self.main_loop.extensions:
                logger.info('%s: checkpoint extension %s not in main loop '
                            'extensions, adding as sub-extension of %s',
                            self.__class__.__name__, self.checkpoint_extension,
                            self)
                self.checkpoint_extension.main_loop = self.main_loop
                self.sub_extensions.append(self.checkpoint_extension)
            else:
                exts = self.main_loop.extensions
                # A smaller index means the checkpoint extension is listed
                # *before* this extension, which is the problematic order.
                if exts.index(self.checkpoint_extension) < exts.index(self):
                    logger.warning('%s: configured checkpointing extension '
                                   'appears before %s in main loop '
                                   'extensions list. This may lead to '
                                   'unwanted results, as the notification '
                                   'that would trigger serialization '
                                   'of a new best will not have been '
                                   'written yet when the checkpointing '
                                   'extension is run.',
                                   self.__class__.__name__, self)
254