Completed — Pull Request — master (#941) by David, created 01:56

blocks.bricks.BatchNormalization — Rating: A

Complexity
    Total Complexity: 26

Size/Duplication
    Total Lines: 172
    Duplicated Lines: 0%

Metric  Value
dl      0
loc     172
rs      10
wmc     26

12 Methods

Rating  Name                    Duplication  Size  Complexity
A       apply()                 0            19    3
A       get_dim()               0            5     2
A       image_size()            0            6     2
A       num_channels()          0            6     1
A       W()                     0            3     1
C       _allocate()             0            35    7
A       output_dim()            0            3     1
A       _initialize()           0            3     1
A       num_output_channels()   0            3     1
A       b()                     0            3     1
A       annotate()              0            4     1
A       __init__()              0            12    3
import collections

import numpy
from picklable_itertools.extras import equizip
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..filter import VariableFilter, get_application_call
from ..roles import (INPUT, WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


class BatchNormalization(RNGMixin, Feedforward):
    """Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, this brick runs in *inference* mode by default,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`batch_normalize`.

    This brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero shift), but you can explicitly pass one
    or both initializers to override this.

    This brick has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.
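
    Examples
    --------
    A minimal inference-mode sketch; the dimension and variable names
    below are illustrative, not part of this module:

    >>> from theano import tensor
    >>> x = tensor.matrix('features')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> y = brick.apply(x)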

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        mean = self.population_mean.copy(name='population_offset')
        stdev = self.population_stdev.copy(name='population_divisor')

        def annotate(var, role):
            add_role(var, role)
            add_annotation(var, self)
            add_annotation(var, application_call)

        annotate(mean, BATCH_NORM_OFFSET)
        annotate(stdev, BATCH_NORM_DIVISOR)

        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, self.W,
                                            self.b, mean, stdev,
                                            mode=('low_mem'
                                                  if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = ((1,) +  # batch axis
                   tuple(1 if broadcast else dim for dim, broadcast in
                         equizip(input_dim, broadcastable)))
        broadcastable = (True,) + broadcastable

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and that all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).
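
    Examples
    --------
    A sketch for a `(channels, height, width)` feature map (the shape
    is illustrative); the channel axis is kept while the spatial axes
    are averaged over:

    >>> brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
    >>> brick.broadcastable
    (False, True, True)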

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Notes
    -----
    All parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick followed
    by the activation itself.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they would be made redundant by the shift
    parameters of the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.
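
    Examples
    --------
    A construction sketch; the activations, dimensions, and
    initialization scheme below are illustrative:

    >>> from blocks.bricks import Rectifier, Softmax
    >>> from blocks.initialization import IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Rectifier(), Softmax()], [784, 200, 10],
    ...                          weights_init=IsotropicGaussian(0.01))
    >>> mlp.initialize()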

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        activations = [Sequence([BatchNormalization().apply, act.apply],
                                name='batch_norm_activation_{}'.format(i))
                       for i, act in enumerate(activations)]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            act.children[0].input_dim = dim


def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.
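
    Examples
    --------
    A rough sketch of obtaining a training graph from an inference
    graph (the variable names are illustrative):

    >>> from theano import tensor
    >>> from blocks.graph import ComputationGraph
    >>> x = tensor.matrix('features')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> inference_graph = ComputationGraph([brick.apply(x)])
    >>> training_graph = batch_normalize(inference_graph)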

    """

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    mean_filter, stdev_filter, input_filter = map(make_variable_filter,
                                                  [BATCH_NORM_OFFSET,
                                                   BATCH_NORM_DIVISOR, INPUT])

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    def get_application_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    means, stdevs, inputs = map(get_application_call_dict,
                                [mean_filter, stdev_filter, input_filter])

    assert (set(means.keys()) == set(stdevs.keys()) and
            set(means.keys()) == set(inputs.keys()))
    assert set(means.values()).isdisjoint(stdevs.values())

    replacements = []
    # Perform replacement for each application call.
    for application_call in means:
        axes = tuple(i for i, b in enumerate(means[application_call]
                                             .broadcastable) if b)
        minibatch_mean = inputs[application_call].mean(axis=axes,
                                                       keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(tensor.var(inputs[application_call],
                                               axis=axes, keepdims=True)
                                    + epsilon)
        minibatch_std.name = 'minibatch_divisor'

        def prepare_replacement(old, new, role, application_call):
            """Add roles and tags to replacement variables."""
            add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
            add_role(new, role)
            add_annotation(new, application_call)
            add_annotation(new, application_call.application.brick)
            new.tag.replacement_of = old
            replacements.append((old, new))

        prepare_replacement(means[application_call], minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    return computation_graph.replace(replacements)


def population_to_minibatch(bn_graph):
    """Get a mapping from population statistics to minibatch estimates.

    Parameters
    ----------
    bn_graph : :class:`~blocks.graph.ComputationGraph`
        Graph returned by :func:`batch_normalize`.

    Returns
    -------
    OrderedDict
        A mapping from each variable representing a population statistic
        to the corresponding minibatch estimate that replaces it in
        the batch-normalized graph.
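
    Examples
    --------
    A sketch (the variable names are illustrative); the resulting pairs
    could be used, for example, to construct moving-average update rules
    for the population statistics:

    >>> from theano import tensor
    >>> from blocks.graph import ComputationGraph
    >>> x = tensor.matrix('features')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> bn_graph = batch_normalize(ComputationGraph([brick.apply(x)]))
    >>> mapping = population_to_minibatch(bn_graph)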

    """
    variables = VariableFilter(roles=[BATCH_NORM_MINIBATCH_ESTIMATE])(bn_graph)
    # The replaced population statistic is recorded on each estimate's tag
    # by batch_normalize, so read it back via `tag`.
    return collections.OrderedDict((v.tag.replacement_of, v)
                                   for v in variables)