Completed
Pull Request — master (#1012) — created 01:32 by David

blocks.bricks.BatchNormalization.num_channels()    Grade: A

Complexity
    Conditions: 1
Size
    Total Lines: 6
Duplication
    Lines: 0
    Ratio: 0 %

Metric    Value
cc        1
dl        0
loc       6
rs        9.4285
import collections

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
       The stabilizing constant for the minibatch standard deviation
       computation (when the brick is run in training mode).
       Added to the variance inside the square root, as in the
       batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.
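
    Examples
    --------
    A minimal usage sketch; the input variable, its dimensionality and
    the variable names below are illustrative rather than part of the
    brick's API:

    >>> from theano import tensor
    >>> x = tensor.matrix('features')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> y_inference = brick.apply(x)   # uses population statistics
    >>> with brick:
    ...     y_training = brick.apply(x)  # uses minibatch statistics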
    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        if self.mean_only:
            # We will be using the Theano support code for doing the
            # actual batch normalization and we need something for these
            # arguments, so just use symbolic 1s.
            scale = tensor.ones_like(self.shift)
            stdev = tensor.ones_like(stdev)
        else:
            scale = self.scale
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if not self.mean_only:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
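        # Reduce over the batch axis plus every per-example axis that was
        # marked broadcastable at allocation time (those axes have length 1
        # in the population statistics).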
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            # No population_stdev is allocated in mean-only mode, so use a
            # symbolic constant; apply() substitutes ones anyway.
            stdev = tensor.ones_like(mean)
        else:
            stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                                numpy.cast[theano.config.floatX](self.epsilon))
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            # Again, no population_stdev exists in mean-only mode.
            stdev = tensor.ones_like(mean)
        else:
            stdev = _add_batch_axis(self.population_stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
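        # Shape of the shared parameters and statistics: axes that are
        # averaged over collapse to length 1 so that these variables
        # broadcast against a full batch of activations.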
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.shift)

        if not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if not self.mean_only:
            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            # Normally these would get annotated by an AnnotatingList, but they
            # aren't in self.parameters.
            add_annotation(self.population_stdev, self)

    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        if not self.mean_only:
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
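        # The channel count may not be known yet; leave a None placeholder
        # for it that the num_channels setter can fill in later.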
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must be a tuple of length at
        least 2. It's assumed that the first axis of this tuple is a
        "channels" axis, which should not be summed over, and all
        remaining dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).
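
    Examples
    --------
    An illustrative construction; the channel count and spatial shape
    below are arbitrary:

    >>> sbn = SpatialBatchNormalization(input_dim=(16, 32, 32))
    >>> sbn.initialize()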
    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.
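
    Examples
    --------
    An illustrative construction; the activations, dimensions and weight
    initializer below are arbitrary choices:

    >>> from blocks.bricks import Tanh
    >>> from blocks.initialization import IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [10, 50, 50],
    ...                          weights_init=IsotropicGaussian(0.01))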
    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        conserve_memory = kwargs.pop('conserve_memory', True)
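        # Wrap each activation in a Sequence so that the output of each
        # Linear brick is batch-normalized before the nonlinearity.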
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=conserve_memory).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    @property
    def conserve_memory(self):
        return self._conserve_memory

    @conserve_memory.setter
    def conserve_memory(self, value):
        self._conserve_memory = value
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].conserve_memory = value

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation.  Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim