Completed
Pull Request — master (#1064), created by Dmitry at 04:46

BatchNormalization._compute_training_statistics()   B

Complexity: Conditions 6
Size: Total Lines 16
Duplication: Lines 0, Ratio 0 %

Metric   Value
cc       6
dl       0
loc      16
rs       8
import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter
    `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    out of the list of parameters you are training, and this brick
    should behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.


    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            scale = tensor.ones_like(self.shift)
            stdev = tensor.ones_like(mean)
        else:
            scale = self.scale
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(scale)
        # Heavy lifting is done by the Theano utility function.
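        # For reference, the call below computes, per Theano's documented
        # behaviour of bn.batch_normalization:
        #     output = (input_ - mean) * (scale / stdev) + shift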
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
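        # Average over the batch axis plus every per-example axis marked
        # broadcastable. Hypothetical example: a population mean with
        # broadcastable pattern (False, True, True) yields axes == (0, 2, 3).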
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                                numpy.cast[theano.config.floatX](self.epsilon))
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.shift)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)

            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            add_annotation(self.population_stdev, self)

    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        if not self.mean_only:
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and that all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
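        # Illustrative example (values are assumptions, not from this file):
        # input_dim == (3, 32, 32) yields broadcastable == (False, True, True),
        # i.e. statistics are computed per channel but shared across all
        # spatial positions.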
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.
    mean_only : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self._mean_only = kwargs.pop('mean_only', False)
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=self._conserve_memory,
                                   mean_only=self._mean_only).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'))

    mean_only = property(partial(_nested_brick_property_getter,
                                 property_name='mean_only'),
                         partial(_nested_brick_property_setter,
                                 property_name='mean_only'))

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
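        # For example (an assumption for illustration): dims == [784, 100, 10]
        # gives the two BatchNormalization bricks input_dim 100 and 10.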
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim