Completed
Pull Request — master (#1120), created by David, 05:16

BatchNormalization.normalization_axes() — A

Complexity: Conditions 3
Size: Total Lines 5
Duplication: Lines 0, Ratio 0 %
Importance: Changes 1, Bugs 0, Features 0

Metric  Value
cc      3        (cyclomatic complexity)
c       1        (changes)
b       0        (bugs)
f       0        (features)
dl      0        (duplicated lines)
loc     5        (lines of code)
rs      9.4285
import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.
    learn_scale : bool, optional
        Whether to include a learned scale parameter ($\gamma$ in [BN]_)
        in this brick. Default is `True`. Has no effect if `mean_only` is
        `True` (i.e. a scale parameter is never learned in mean-only mode).
    learn_shift : bool, optional
        Whether to include a learned shift parameter ($\beta$ in [BN]_)
        in this brick. Default is `True`.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times.
    (Every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes its applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once.)

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.


    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, learn_scale=True,
                 learn_shift=True, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.learn_scale = learn_scale
        self.learn_shift = learn_shift
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(self.scale)
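        # In either mode the result is the affine-transformed normalization
        # scale * (input_ - mean) / stdev + shift, broadcast over the
        # singleton batch axis of mean, stdev, shift and scale.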
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        axes = self.normalization_axes
        mean = input_.mean(axis=axes, keepdims=True)
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
                   tensor.sqr(mean))
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    @property
    def normalization_axes(self):
        return (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)

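    # Taken together with the statistics above: for an illustrative
    # (batch, channels, height, width) input with
    # broadcastable=(False, True, True), population_mean has shape
    # (channels, 1, 1), normalization_axes is (0, 2, 3), and the training
    # statistics reduce over the batch and spatial axes, giving a mean of
    # shape (1, channels, 1, 1) and stdev = sqrt(E[x^2] - E[x]^2 + epsilon)
    # of the same shape.
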
    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))
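        # As an illustration, input_dim=(16, 32, 32) with
        # broadcastable=(False, True, True) gives var_dim=(16, 1, 1),
        # i.e. one shift, scale and population statistic per channel.
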
        # "beta", from the Ioffe & Szegedy manuscript.
        if self.learn_shift:
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                            broadcastable=broadcastable)
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
            self.parameters.append(self.shift)
        else:
            self.shift = tensor.constant(0, dtype=theano.config.floatX)

        if self.learn_scale and not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)
        else:
            self.scale = tensor.constant(1., dtype=theano.config.floatX)

        self._allocate_population_statistics(var_dim, broadcastable)

    def _allocate_population_statistics(self, var_dim, broadcastable):
        def _allocate_buffer(name, role, value):
            # These aren't technically parameters, in that they should not be
            # learned using the same cost function as other model parameters.
            population_buffer = shared_floatx(value * numpy.ones(var_dim),
                                              broadcastable=broadcastable,
                                              name=name)
            add_role(population_buffer, role)
            # Normally these would get annotated by an AnnotatingList, but they
            # aren't in self.parameters.
            add_annotation(population_buffer, self)
            return population_buffer

        self.population_mean = _allocate_buffer('population_mean',
                                                BATCH_NORM_POPULATION_MEAN,
                                                numpy.zeros(var_dim))

        self.population_stdev = _allocate_buffer('population_stdev',
                                                 BATCH_NORM_POPULATION_STDEV,
                                                 numpy.ones(var_dim))

    def _initialize(self):
        # We gate with is_shared_variable rather than relying on
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
        # scenario where those flags are changed post-allocation but
        # pre-initialization. This ensures that such a change simply has no
        # effect rather than doing an inconsistent combination of things.
        if is_shared_variable(self.shift):
            self.shift_init.initialize(self.shift, self.rng)
        if is_shared_variable(self.scale):
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


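# A minimal usage sketch (illustration only; this helper and its dimensions
# are not part of the brick API). It builds the default inference-mode graph
# and then, via the context manager described in the class docstring, a
# minibatch "training" mode graph.
def _example_batch_normalization_usage():
    brick = BatchNormalization(input_dim=(16, 32, 32),
                               broadcastable=(False, True, True))
    brick.initialize()
    x = tensor.tensor4('x')
    y_inference = brick.apply(x)   # normalized with population statistics
    with brick:                    # training mode: minibatch statistics
        y_training = brick.apply(x)
    return y_inference, y_training

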
class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    mean_only : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_scale : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_shift : bool, optional, by keyword only
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as these would be made redundant by the shift
    parameters in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias=True` if you really want them for some reason.

    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
    all created :class:`BatchNormalization` bricks as allocation
    config.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self.mean_only = kwargs.pop('mean_only', False)
        self.learn_scale = kwargs.pop('learn_scale', True)
        self.learn_shift = kwargs.pop('learn_shift', True)

        activations = [
            Sequence([
                (BatchNormalization(conserve_memory=self._conserve_memory)
                 .apply),
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # conserve_memory is a bit special in that it can be modified
    # after construction/allocation and still have a valid effect on
    # apply(). Thus we propagate down all property sets.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'),
                               doc="Conserve memory.")

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation.  Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
                setattr(act.children[0], attr, getattr(self, attr))
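

# A minimal construction sketch (illustration only; the activation bricks,
# dimensions and the blocks.bricks import are assumptions, not part of this
# module).
def _example_batch_normalized_mlp():
    from blocks.bricks import Rectifier, Logistic
    mlp = BatchNormalizedMLP([Rectifier(), Logistic()], [30, 20, 10],
                             mean_only=True)
    # Allocation pushes input_dim and mean_only down to each contained
    # BatchNormalization brick (see _push_allocation_config above).
    mlp.initialize()
    x = tensor.matrix('features')
    return mlp.apply(x)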
432