Completed
Pull Request — master (#1062)
created by David, 04:35

BatchNormalization.num_output_channels()   A

Complexity:  Conditions 1
Size:        Total Lines 3
Duplication: Lines 0, Ratio 0 %

Metric  Value
cc      1
dl      0
loc     3
rs      10
import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, this brick runs, by default, in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            scale = tensor.ones_like(self.shift)
            stdev = tensor.ones_like(mean)
        else:
            scale = self.scale
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
                   tensor.sqr(mean))
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.shift)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)

            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)

            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            add_annotation(self.population_stdev, self)

    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        if not self.mean_only:
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels
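

# ---------------------------------------------------------------------------
# Editor's illustrative sketch -- not part of the original module. It shows
# how the inference/training behaviour described in the class docstring is
# typically exercised: by default apply() normalizes with the population
# statistics, while the blocks.graph.batch_normalization context manager
# referenced in the Notes above (or the brick itself, used as a context
# manager) switches applications into minibatch "training" mode. The
# function name and the input_dim value are made up for the example.
def _example_inference_vs_training():
    from theano import tensor

    from ..graph import batch_normalization

    x = tensor.matrix('features')
    brick = BatchNormalization(input_dim=100)
    brick.initialize()

    # Inference mode (the default): normalize with population_mean and
    # population_stdev, which the user is expected to update during
    # training, e.g. with a moving average of minibatch statistics.
    y_inference = brick.apply(x)

    # Training mode: inside the context manager, apply() normalizes with
    # minibatch mean and standard deviation instead.
    with batch_normalization(brick):
        y_training = brick.apply(x)
    return y_inference, y_training

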
class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()
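

# ---------------------------------------------------------------------------
# Editor's illustrative sketch -- not part of the original module. It shows
# SpatialBatchNormalization applied to convolutional feature maps:
# statistics are shared across all spatial locations, i.e. broadcastable
# becomes (False, True, True). The function name and the concrete sizes
# are made up for the example.
def _example_spatial_batch_norm():
    from theano import tensor

    # A batch of feature maps: (batch, channels, height, width).
    images = tensor.tensor4('images')
    brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
    brick.initialize()
    return brick.apply(images)

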
class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.
    mean_only : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they would be made redundant by the shift
    parameters in the :class:`BatchNormalization` bricks being added.
    Pass `use_bias` with a value of `True` if you really want this for
    some reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self._mean_only = kwargs.pop('mean_only', False)
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=self._conserve_memory,
                                   mean_only=self._mean_only).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'))

    mean_only = property(partial(_nested_brick_property_getter,
                                 property_name='mean_only'),
                         partial(_nested_brick_property_setter,
                                 property_name='mean_only'))

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
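

# ---------------------------------------------------------------------------
# Editor's illustrative sketch -- not part of the original module. It shows
# a small BatchNormalizedMLP being built: each activation is wrapped in a
# Sequence of [BatchNormalization.apply, activation.apply], and the Linear
# bricks are created with use_bias=False since the batch normalization
# shift parameter plays the role of a bias. Rectifier and IsotropicGaussian
# are assumed to live in blocks.bricks.simple and blocks.initialization;
# the function name and layer sizes are made up for the example.
def _example_batch_normalized_mlp():
    from theano import tensor

    from ..initialization import IsotropicGaussian
    from .simple import Rectifier

    x = tensor.matrix('features')
    mlp = BatchNormalizedMLP(activations=[Rectifier(), Rectifier()],
                             dims=[784, 256, 10],
                             weights_init=IsotropicGaussian(0.01))
    mlp.initialize()
    return mlp.apply(x)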