Completed
Pull Request — master (#1109)
by David
04:48
created

EarlyStopping   A

Complexity

Total Complexity 11

Size/Duplication

Total Lines 76
Duplicated Lines 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
c 1
b 0
f 0
dl 0
loc 76
rs 10
wmc 11

2 Methods

Rating   Name   Duplication   Size   Complexity  
B do() 0 13 5
B __init__() 0 23 6
1
import logging
2
3
from . import FinishAfter, CompositeExtension
4
from .training import TrackTheBest
5
from .predicates import OnLogRecord
6
7
8
logger = logging.getLogger(__name__)
9
10
11
class FinishIfNoImprovementAfter(FinishAfter):
    """Stop after improvements have ceased for a given period.

    Parameters
    ----------
    notification_name : str
        The name of the log record to look for which indicates a new
        best performer has been found.  Note that the value of this
        record is not inspected.
    iterations : int, optional
        The number of iterations to wait for a new best. Exactly one of
        `iterations` or `epochs` must not be `None` (default).
    epochs : int, optional
        The number of epochs to wait for a new best. Exactly one of
        `iterations` or `epochs` must not be `None` (default).
    patience_log_record : str, optional
        The name under which to record the number of iterations or
        epochs (whichever measure is in use) we are currently willing
        to wait for a new best performer. Defaults to
        `notification_name + '_patience_epochs'` or
        `notification_name + '_patience_iterations'`, depending on
        which measure is being used.

    Notes
    -----
    By default, runs after each epoch. This can be manipulated via
    keyword arguments (see :class:`blocks.extensions.SimpleExtension`).

    Training is stopped as soon as the remaining patience reaches zero
    or drops below it. The remaining patience is only recomputed when
    this extension runs, so it can step past zero -- e.g. when patience
    is measured in iterations but the extension only runs after each
    epoch. The value written to the log may therefore be negative.

    Raises
    ------
    ValueError
        If both or neither of `iterations` and `epochs` are given.

    """
    def __init__(self, notification_name, iterations=None, epochs=None,
                 patience_log_record=None, **kwargs):
        if (epochs is None) == (iterations is None):
            raise ValueError("Need exactly one of epochs or iterations "
                             "to be specified")
        self.notification_name = notification_name
        self.iterations = iterations
        self.epochs = epochs
        kwargs.setdefault('after_epoch', True)
        # Status counters captured the last time the notification was
        # seen; ``None`` until the first new best has been reported.
        self.last_best_iter = self.last_best_epoch = None
        if patience_log_record is None:
            self.patience_log_record = (notification_name + '_patience' +
                                        ('_epochs' if self.epochs is not None
                                         else '_iterations'))
        else:
            self.patience_log_record = patience_log_record
        super(FinishIfNoImprovementAfter, self).__init__(**kwargs)

    def update_best(self):
        """Remember the current position if a new best was just reported.

        Only the presence of `notification_name` in the current log row
        is checked; its value is ignored.

        """
        # Here mainly so we can easily subclass different criteria.
        log = self.main_loop.log
        if self.notification_name in log.current_row:
            self.last_best_iter = log.status['iterations_done']
            self.last_best_epoch = log.status['epochs_done']

    def do(self, which_callback, *args):
        """Log the remaining patience and request a stop once exhausted.

        Does nothing until a first best has been observed. Otherwise the
        remaining patience (in epochs or iterations, whichever was
        configured) is written to the current log row under
        `patience_log_record`, and :class:`FinishAfter` is invoked when
        it is zero or negative.

        """
        self.update_best()
        # If we haven't encountered a best yet, then we should just bail.
        if self.last_best_iter is None:
            return
        status = self.main_loop.log.status
        if self.epochs is not None:
            since = status['epochs_done'] - self.last_best_epoch
            patience = self.epochs - since
        else:
            since = status['iterations_done'] - self.last_best_iter
            patience = self.iterations - since
        logger.debug('%s: Writing patience of %d to current log record (%s) '
                     'at iteration %d', self.__class__.__name__, patience,
                     self.patience_log_record, status['iterations_done'])
        self.main_loop.log.current_row[self.patience_log_record] = patience
        # ``<=`` rather than ``==``: patience moves in steps of however many
        # units elapsed since the last run, so it can overshoot zero (e.g.
        # iteration-based patience checked only after each epoch).
        if patience <= 0:
            super(FinishIfNoImprovementAfter, self).do(which_callback,
                                                       *args)
84
85
86
class EarlyStopping(CompositeExtension):
87
    """A 'batteries-included' early stopping extension.
88
89
    Parameters
90
    ----------
91
    record_name : str
92
        The log record entry whose value represents the quantity to base
93
        early stopping decisions on, e.g. some measure of validation set
94
        performance.
95
    checkpoint_extension : :class:`~blocks.extensions.Checkpoint`, optional
96
        A :class:`~blocks.extensions.Checkpoint` instance to configure to
97
        save a checkpoint when a new best performer is found.
98
    checkpoint_filename : str, optional
99
        The filename to use for the 'current best' checkpoint. Must be
100
        provided if ``checkpoint_extension`` is specified.
101
    notification_name : str, optional
102
        The name to be written in the log when a new best-performing
103
        model is found. Defaults to ``record_name + '_best_so_far'``.
104
    choose_best : callable, optional
105
        See :class:`TrackTheBest`.
106
    iterations : int, optional
107
        See :class:`FinishIfNoImprovementAfter`.
108
    epochs : int, optional
109
        See :class:`FinishIfNoImprovementAfter`.
110
111
    Notes
112
    -----
113
    Trigger keyword arguments will affect how often the log is inspected
114
    for the record name (in order to determine if a new best has been
115
    found), as well as how often the stopping criterion is evaluated. By
    default, ``after_epoch`` is set,
116
    as is ``before_training``, where some sanity checks are performed
117
    (including the optional self-management of checkpointing).
118
119
    If ``checkpoint_extension`` is not in the main loop's extensions list
120
    when the `before_training` trigger is run, it will be added as a
121
    sub-extension of this object.
122
123
    """
124
125
    def __init__(self, record_name, checkpoint_extension=None,
                 checkpoint_filename=None, notification_name=None,
                 choose_best=min, iterations=None, epochs=None, **kwargs):
        """Assemble the best-tracking and stopping sub-extensions.

        See the class docstring for the meaning of each argument.

        Raises
        ------
        ValueError
            If `checkpoint_extension` is given without
            `checkpoint_filename`, or if the wrapped
            :class:`FinishIfNoImprovementAfter` rejects the combination
            of `iterations` and `epochs`.

        """
        notification_name = (record_name + '_best_so_far'
                             if notification_name is None
                             else notification_name)
        # Applied before the sub-extensions are built so that both of them
        # default to running after every epoch.
        kwargs.setdefault('after_epoch', True)
        tracker = TrackTheBest(record_name, notification_name,
                               choose_best=choose_best, **kwargs)
        stopper = FinishIfNoImprovementAfter(notification_name,
                                             iterations=iterations,
                                             epochs=epochs, **kwargs)
        self.checkpoint_extension = checkpoint_extension
        wants_checkpoint = checkpoint_extension and checkpoint_filename
        if wants_checkpoint:
            # NOTE(review): the condition is registered for 'after_batch'
            # only, while the notification is written by a tracker that
            # defaults to 'after_epoch'; confirm the checkpoint still fires
            # under the default triggers.
            checkpoint_extension.add_condition(
                ['after_batch'], OnLogRecord(notification_name),
                (checkpoint_filename,))
        elif checkpoint_extension is not None and checkpoint_filename is None:
            raise ValueError('checkpoint_extension specified without '
                             'checkpoint_filename')
        # Set only after the sub-extensions were built: ``before_training``
        # applies to this composite (for the checks in ``do``), not to them.
        kwargs.setdefault('before_training', True)
        super(EarlyStopping, self).__init__([tracker, stopper], **kwargs)
148
149
    def do(self, which_callback, *args):
        """On ``before_training``, check how checkpointing is wired up.

        If a checkpoint extension was configured but is not among the main
        loop's extensions, it is attached to the main loop and adopted as a
        sub-extension of this object. If it is among them but listed ahead
        of this extension, a warning is logged, because it would then run
        before the new-best notification has been written.

        """
        ckpt = self.checkpoint_extension
        if which_callback == 'before_training' and ckpt:
            exts = self.main_loop.extensions
            if ckpt not in exts:
                logger.info('%s: checkpoint extension %s not in main loop '
                            'extensions, adding as sub-extension of %s',
                            self.__class__.__name__, ckpt, self)
                ckpt.main_loop = self.main_loop
                self.sub_extensions.append(ckpt)
            # ``self in exts`` guards ``list.index``, which raises ValueError
            # when this extension is nested rather than top-level.
            elif self in exts and exts.index(ckpt) < exts.index(self):
                logger.warning('%s: configured checkpointing extension '
                               'appears before %s in main loop '
                               'extensions list. This may lead to '
                               'unwanted results, as the notification '
                               'that would trigger serialization '
                               'of a new best will not have been '
                               'written yet when the checkpointing '
                               'extension is run.',
                               self.__class__.__name__, self)
169