EarlyStopping   A
last analyzed

Complexity

Total Complexity 11

Size/Duplication

Total Lines 169
Duplicated Lines 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
c 2
b 0
f 0
dl 0
loc 169
rs 10
wmc 11

2 Methods

Rating   Name   Duplication   Size   Complexity  
B __init__() 0 23 6
B do() 0 20 5
1
import logging
2
3
from . import FinishAfter, CompositeExtension
4
from .training import TrackTheBest
5
from .predicates import OnLogRecord
6
7
8
logger = logging.getLogger(__name__)
9
10
11
class FinishIfNoImprovementAfter(FinishAfter):
    """Stop after improvements have ceased for a given period.

    Parameters
    ----------
    notification_name : str
        The name of the log record to look for which indicates a new
        best performer has been found.  Note that the value of this
        record is not inspected; only its presence in the current log
        row matters.
    iterations : int, optional
        The number of iterations to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    epochs : int, optional
        The number of epochs to wait for a new best. Exactly one of
        `iterations` or `epochs` must be not `None` (default).
    patience_log_record : str, optional
        The name under which to record the number of iterations (or
        epochs, depending on the measure in use) we are currently
        willing to wait for a new best performer. Defaults to
        `notification_name + '_patience_epochs'` or
        `notification_name + '_patience_iterations'`, depending on
        which measure is being used.

    Raises
    ------
    ValueError
        If both or neither of `iterations` and `epochs` are given.

    Notes
    -----
    By default, runs after each epoch. This can be manipulated via
    keyword arguments (see :class:`blocks.extensions.SimpleExtension`).

    Training is stopped as soon as the remaining patience is zero *or
    less*. Patience is only evaluated when this extension is triggered,
    so when e.g. ``iterations`` is used together with the default
    ``after_epoch`` trigger, the counter can skip past zero between two
    checks.

    """
    def __init__(self, notification_name, iterations=None, epochs=None,
                 patience_log_record=None, **kwargs):
        # Exactly one patience measure must be chosen.
        if (epochs is None) == (iterations is None):
            raise ValueError("Need exactly one of epochs or iterations "
                             "to be specified")
        self.notification_name = notification_name
        self.iterations = iterations
        self.epochs = epochs
        kwargs.setdefault('after_epoch', True)
        # No best has been seen yet; ``do`` is a no-op until one is.
        self.last_best_iter = self.last_best_epoch = None
        if patience_log_record is None:
            unit = '_epochs' if epochs is not None else '_iterations'
            patience_log_record = notification_name + '_patience' + unit
        self.patience_log_record = patience_log_record
        super(FinishIfNoImprovementAfter, self).__init__(**kwargs)

    def update_best(self):
        """Remember the current iteration/epoch if a new best was logged.

        A new best is signalled by the presence of ``notification_name``
        in the current log row.
        """
        # Here mainly so we can easily subclass different criteria.
        if self.notification_name in self.main_loop.log.current_row:
            self.last_best_iter = self.main_loop.log.status['iterations_done']
            self.last_best_epoch = self.main_loop.log.status['epochs_done']

    def do(self, which_callback, *args):
        """Log the remaining patience and request stopping when exhausted.

        The remaining patience (in epochs or iterations) is written to the
        current log row under ``patience_log_record``. Once it is zero or
        negative, :class:`FinishAfter`'s ``do`` is invoked.
        """
        self.update_best()
        # If we haven't encountered a best yet, then we should just bail.
        if self.last_best_iter is None:
            return
        status = self.main_loop.log.status
        if self.epochs is not None:
            since = status['epochs_done'] - self.last_best_epoch
            patience = self.epochs - since
        else:
            since = status['iterations_done'] - self.last_best_iter
            patience = self.iterations - since
        logger.debug('%s: Writing patience of %d to current log record (%s) '
                     'at iteration %d', self.__class__.__name__, patience,
                     self.patience_log_record,
                     status['iterations_done'])
        self.main_loop.log.current_row[self.patience_log_record] = patience
        # ``<= 0`` rather than ``== 0``: when the extension is triggered less
        # often than the patience unit advances (e.g. iteration patience
        # checked only after each epoch), patience can jump from a positive
        # value straight to a negative one and would otherwise never stop.
        if patience <= 0:
            super(FinishIfNoImprovementAfter, self).do(which_callback,
                                                       *args)
84
85
86
class EarlyStopping(CompositeExtension):
    """A 'batteries-included' early stopping extension.

    Parameters
    ----------
    record_name : str
        The log record entry whose value represents the quantity to base
        early stopping decisions on, e.g. some measure of validation set
        performance.
    checkpoint_extension : :class:`~blocks.extensions.Checkpoint`, optional
        A :class:`~blocks.extensions.Checkpoint` instance to configure to
        save a checkpoint when a new best performer is found.
    checkpoint_filename : str, optional
        The filename to use for the 'current best' checkpoint. Must be
        provided (and non-empty) if ``checkpoint_extension`` is specified.
    notification_name : str, optional
        The name to be written in the log when a new best-performing
        model is found. Defaults to ``record_name + '_best_so_far'``.
    choose_best : callable, optional
        See :class:`TrackTheBest`. Defaults to the built-in ``min``, i.e.
        lower values of ``record_name`` are considered better.
    iterations : int, optional
        See :class:`FinishIfNoImprovementAfter`.
    epochs : int, optional
        See :class:`FinishIfNoImprovementAfter`.

    Raises
    ------
    ValueError
        If ``checkpoint_extension`` is given without a non-empty
        ``checkpoint_filename``, or if not exactly one of ``iterations``
        and ``epochs`` is given.

    Notes
    -----
    .. warning::
        If you want the best model to be saved, you need to specify
        a value for the ``checkpoint_extension`` and
        ``checkpoint_filename`` arguments!

    Trigger keyword arguments will affect how often the log is inspected
    for the record name (in order to determine if a new best has been
    found), as well as how often a decision is made about whether to
    continue training. By default, ``after_epoch`` is set,
    as is ``before_training``, where some sanity checks are performed
    (including the optional self-management of checkpointing).

    If ``checkpoint_extension`` is not in the main loop's extensions list
    when the ``before_training`` trigger is run, it will be added as a
    sub-extension of this object. If it is in the list but placed before
    this extension, a warning is logged.

    Examples
    --------
    To simply track the best value of a log entry and halt training
    when it hasn't improved in a sufficient amount of time, we could
    use e.g.

    >>> stopping_ext = EarlyStopping('valid_error', iterations=100)

    which would halt training if a new minimum ``valid_error`` has not
    been achieved in 100 iterations (i.e. minibatches/steps). To measure
    in terms of epochs (which usually correspond to passes through the
    training set), you could use

    >>> epoch_stop_ext = EarlyStopping('valid_error', epochs=5)

    If you are tracking a log entry where there's a different definition
    of 'best', you can provide a callable that takes two log values
    and returns the one that :class:`EarlyStopping` should consider
    "better". For example, if you were tracking accuracy, where higher
    is better, you could pass the built-in ``max`` function:

    >>> max_acc_stop = EarlyStopping('valid_accuracy', choose_best=max,
    ...                              notification_name='highest_acc',
    ...                              epochs=10)

    Above we've also provided an alternate notification name, meaning
    a value of ``True`` will be written under the entry name
    ``highest_acc`` whenever a new highest accuracy is found (by default
    this would be a name like ``valid_accuracy_best_so_far``).

    Let's configure a checkpointing extension to save the model and log
    (but not the main loop):

    >>> from blocks.extensions.saveload import Checkpoint
    >>> checkpoint = Checkpoint('my_model.tar', save_main_loop=False,
    ...                         save_separately=['model', 'log'],
    ...                         after_epoch=True)

    When we pass this object to :class:`EarlyStopping`, along with a
    different filename, :class:`EarlyStopping` will configure that same
    checkpointing extension to *also* serialize to ``best_model.tar`` when
    a new best value of validation error is achieved.

    >>> stopping = EarlyStopping('valid_error', checkpoint,
    ...                          'best_model.tar', epochs=5)

    Finally, we'll set up the main loop:

    >>> from blocks.main_loop import MainLoop
    >>> # You would, of course, use a real algorithm and data stream here.
    >>> algorithm = data_stream = None
    >>> main_loop = MainLoop(algorithm=algorithm,
    ...                      data_stream=data_stream,
    ...                      extensions=[stopping, checkpoint])

    Note that you do want to place the checkpoint extension *after*
    the stopping extension, so that the appropriate log records
    have been written in order to trigger the checkpointing
    extension.

    It's also possible to in-line the creation of the
    checkpointing extension:

    >>> main_loop = MainLoop(algorithm=algorithm,
    ...                      data_stream=data_stream,
    ...                      extensions=[EarlyStopping(
    ...                          'valid_error',
    ...                          Checkpoint('my_model.tar',
    ...                                     save_main_loop=False,
    ...                                     save_separately=['model',
    ...                                                      'log'],
    ...                                     after_epoch=True),
    ...                          'my_best_model.tar',
    ...                          epochs=5)])

    Note that we haven't added the checkpointing extension to the
    main loop's extensions list. No problem: :class:`EarlyStopping` will
    detect that it isn't being managed by the main loop and manage it
    internally. It will automatically be executed in the right order
    for it to function properly alongside :class:`EarlyStopping`.

    """
    def __init__(self, record_name, checkpoint_extension=None,
                 checkpoint_filename=None, notification_name=None,
                 choose_best=min, iterations=None, epochs=None, **kwargs):
        if notification_name is None:
            notification_name = record_name + '_best_so_far'
        # Validate before touching anything: previously an empty filename
        # fell through both branches and silently disabled best-model
        # checkpointing.
        if checkpoint_extension is not None and not checkpoint_filename:
            raise ValueError('checkpoint_extension specified without '
                             'checkpoint_filename')
        kwargs.setdefault('after_epoch', True)
        tracking_ext = TrackTheBest(record_name, notification_name,
                                    choose_best=choose_best, **kwargs)
        stopping_ext = FinishIfNoImprovementAfter(notification_name,
                                                  iterations=iterations,
                                                  epochs=epochs,
                                                  **kwargs)
        self.checkpoint_extension = checkpoint_extension
        if checkpoint_extension is not None:
            # Make the checkpoint also save to ``checkpoint_filename``
            # whenever the new-best notification is in the current log row.
            # NOTE(review): the condition is registered on 'after_batch'
            # only, while tracking runs 'after_epoch' by default -- confirm
            # against MainLoop callback ordering that it actually fires.
            checkpoint_extension.add_condition(['after_batch'],
                                               OnLogRecord(notification_name),
                                               (checkpoint_filename,))
        # Set only now so the sub-extensions above don't inherit it; the
        # sanity checks in ``do`` belong to this composite alone.
        kwargs.setdefault('before_training', True)
        super(EarlyStopping, self).__init__([tracking_ext, stopping_ext],
                                            **kwargs)

    def do(self, which_callback, *args):
        """Check or adopt the checkpoint extension before training starts.

        If the checkpoint extension is not in the main loop's extension
        list, it is adopted as a sub-extension (run after the tracking and
        stopping sub-extensions). If it is in the list but placed before
        this extension, a warning is logged.
        """
        if (which_callback != 'before_training' or
                self.checkpoint_extension is None):
            return
        checkpoint = self.checkpoint_extension
        exts = self.main_loop.extensions
        if checkpoint not in exts:
            # Guard against adopting the same checkpoint twice should
            # ``before_training`` fire more than once.
            if checkpoint not in self.sub_extensions:
                logger.info('%s: checkpoint extension %s not in main loop '
                            'extensions, adding as sub-extension of %s',
                            self.__class__.__name__, checkpoint, self)
                checkpoint.main_loop = self.main_loop
                self.sub_extensions.append(checkpoint)
        elif exts.index(checkpoint) < exts.index(self):
            # The checkpoint runs *before* this extension, so it would see
            # the log row before the new-best notification is written.
            logger.warning('%s: configured checkpointing extension '
                           'appears before this extension in main loop '
                           'extensions list. This may lead to '
                           'unwanted results, as the notification '
                           'that would trigger serialization '
                           'of a new best will not have been '
                           'written yet when the checkpointing '
                           'extension is run.', self.__class__.__name__)
255