blocks/graph/bn.py

"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training
graph, which uses minibatch statistics in place of population
statistics.

"""
import collections
import contextlib
from functools import partial

import theano
from toolz import isdistinct

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


def _training_mode_application_calls(application_calls):
    """Filter for application calls made in 'training mode'."""
    from ..bricks import BatchNormalization
    out = []
    for app_call in application_calls:
        assert isinstance(app_call.application.brick, BatchNormalization)
        assert app_call.application.application == BatchNormalization.apply
        if app_call.metadata.get('training_mode', False):
            out.append(app_call)
    return out


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch-normalized training graph. It can
    be somewhat less convenient, as it requires building the graph twice
    if one wishes to monitor the output of the inference graph during
    training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which are
    initialized to 0 (mean) and 1 (standard deviation) by default.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    To construct an output with batch normalization enabled, i.e. one
    that normalizes pre-activations using per-minibatch statistics, we
    simply make a similar call inside a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False
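
    The pairs needed to update the population statistics can then be
    extracted from the resulting training graph with
    :func:`get_batch_normalization_updates` (defined below):

    >>> from blocks.graph import ComputationGraph
    >>> pop_updates = get_batch_normalization_updates(
    ...     ComputationGraph([y_bn]))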

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # We can use neither contextlib.nested() (deprecated) nor ExitStack
    # (not available on Python 2.7), so enter and exit each brick's
    # context manually.
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        for brick in bn[::-1]:
            brick.__exit__()


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization`, for an alternative way to produce
    batch-normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which are
    initialized to 0 (mean) and 1 (standard deviation) by default.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False
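
    The minibatch statistics needed to update the population statistics
    can then be extracted with :func:`get_batch_normalization_updates`
    (defined below):

    >>> pop_updates = get_batch_normalization_updates(cg)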

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode. Take a
    # list copy of the keys, as we delete from the dicts while iterating.
    app_calls = list(inputs.keys())
    remove = _training_mode_application_calls(app_calls)
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        # Re-apply the brick inside its context manager, which puts it in
        # training mode so that the new output uses minibatch statistics.
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        new_app_call = get_application_call(new_output)
        assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)


def get_batch_normalization_updates(training_graph, allow_duplicates=False):
    """Extract correspondences for learning BN population statistics.

    Parameters
    ----------
    training_graph : :class:`~blocks.graph.ComputationGraph`
        A graph of expressions wherein "training mode" batch normalization
        is taking place.
    allow_duplicates : bool, optional
        If `True`, allow multiple training-mode application calls from the
        same :class:`~blocks.bricks.BatchNormalization` instance, and
        return pairs corresponding to all of them. It's then the user's
        responsibility to do something sensible to resolve the duplicates.

    Returns
    -------
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore (if `allow_duplicates` is
        True) a single population variable may map to several different
        minibatch variables, and appear multiple times in this mapping.
        This can happen in recurrent models, Siamese networks, or other
        models that reuse pathways.

    Notes
    -----
    Used in their raw form, these updates will simply overwrite the
    population statistics with the minibatch statistics at every gradient
    step. You will probably want to transform these pairs into something
    more sensible, such as keeping a moving average of minibatch values,
    or accumulating an average over the entire training set once every
    few epochs, as sketched in the example below.
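
    Examples
    --------
    One common pattern (a sketch; ``training_cg`` is assumed to be a
    batch-normalized training graph, and the decay rate 0.1 is just an
    illustrative hyperparameter) is an exponential moving average::

        pop_updates = get_batch_normalization_updates(training_cg)
        decay = 0.1  # fraction of the minibatch statistic blended in
        extra_updates = [(pop, (1 - decay) * pop + decay * mb)
                         for pop, mb in pop_updates]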

    """
    from ..bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call
    var_filter = VariableFilter(bricks=[BatchNormalization], roles=[OUTPUT])
    all_app_calls = map(get_application_call, var_filter(training_graph))
    train_app_calls = _training_mode_application_calls(all_app_calls)
    if len(train_app_calls) == 0:
        raise ValueError("no training mode BatchNormalization "
                         "applications found in graph")
    bricks = [c.application.brick for c in train_app_calls]

    if not allow_duplicates and not isdistinct(bricks):
        raise ValueError('multiple applications of the same '
                         'BatchNormalization brick; pass '
                         'allow_duplicates=True to override this check')

    # Pair a brick's population statistic (a shared variable) with the
    # corresponding minibatch statistic stored in the application call's
    # metadata.
    def extract_pair(brick_attribute, metadata_key, app_call):
        return (getattr(app_call.application.brick, brick_attribute),
                app_call.metadata[metadata_key])

    mean_pair = partial(extract_pair, 'population_mean', 'offset')
    stdev_pair = partial(extract_pair, 'population_stdev', 'divisor')
    # Bricks configured with mean_only=True contribute only a mean update.
    return sum([[mean_pair(a), stdev_pair(a)]
                if not a.application.brick.mean_only
                else [mean_pair(a)]
                for a in train_app_calls], [])