BatchNormalization._allocate()   F
last analyzed

Complexity

Conditions 10

Size

Total Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 10
c 2
b 0
f 0
dl 0
loc 32
rs 3.1304

How to fix   Complexity   

Complexity

Complex classes like BatchNormalization._allocate() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import collections
2
from functools import partial
3
4
import numpy
5
from picklable_itertools.extras import equizip
6
import theano
7
from theano import tensor
8
from theano.tensor.nnet import bn
9
10
from ..graph import add_annotation
11
from ..initialization import Constant
12
from ..roles import (BATCH_NORM_POPULATION_MEAN,
13
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
14
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
15
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
16
                     add_role)
17
from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable
18
from .base import lazy, application
19
from .sequences import Sequence, Feedforward, MLP
20
from .interfaces import RNGMixin
21
22
23
def _add_batch_axis(var):
24
    """Prepend a singleton axis to a TensorVariable and name it."""
25
    new_var = new_var = tensor.shape_padleft(var)
26
    new_var.name = 'shape_padleft({})'.format(var.name)
27
    return new_var
28
29
30
def _add_role_and_annotate(var, role, annotations=()):
31
    """Add a role and zero or more annotations to a variable."""
32
    add_role(var, role)
33
    for annotation in annotations:
34
        add_annotation(var, annotation)
35
36
37
class BatchNormalization(RNGMixin, Feedforward):
38
    r"""Normalizes activations, parameterizes a scale and shift.
39
40
    Parameters
41
    ----------
42
    input_dim : int or tuple
43
        Shape of a single input example. It is assumed that a batch axis
44
        will be prepended to this.
45
    broadcastable : tuple, optional
46
        Tuple of the same length as `input_dim` which specifies which of
47
        the per-example axes should be averaged over to compute means and
48
        standard deviations. For example, in order to normalize over all
49
        spatial locations in a `(batch_index, channels, height, width)`
50
        image, pass `(False, True, True)`. The batch axis is always
51
        averaged out.
52
    conserve_memory : bool, optional
53
        Use an implementation that stores less intermediate state and
54
        therefore uses less memory, at the expense of 5-10% speed. Default
55
        is `True`.
56
    epsilon : float, optional
57
       The stabilizing constant for the minibatch standard deviation
58
       computation (when the brick is run in training mode).
59
       Added to the variance inside the square root, as in the
60
       batch normalization paper.
61
    scale_init : object, optional
62
        Initialization object to use for the learned scaling parameter
63
        ($\\gamma$ in [BN]_). By default, uses constant initialization
64
        of 1.
65
    shift_init : object, optional
66
        Initialization object to use for the learned shift parameter
67
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
68
    mean_only : bool, optional
69
        Perform "mean-only" batch normalization as described in [SK2016]_.
70
    learn_scale : bool, optional
71
        Whether to include a learned scale parameter ($\\gamma$ in [BN]_)
72
        in this brick. Default is `True`. Has no effect if `mean_only` is
73
        `True` (i.e. a scale parameter is never learned in mean-only mode).
74
    learn_shift : bool, optional
75
        Whether to include a learned shift parameter ($\\beta$ in [BN]_)
76
        in this brick. Default is `True`.
77
78
    Notes
79
    -----
80
    In order for trained models to behave sensibly immediately upon
81
    upon deserialization, by default, this brick runs in *inference* mode,
82
    using a population mean and population standard deviation (initialized
83
    to zeros and ones respectively) to normalize activations. It is
84
    expected that the user will adapt these during training in some
85
    fashion, independently of the training objective, e.g. by taking a
86
    moving average of minibatch-wise statistics.
87
88
    In order to *train* with batch normalization, one must obtain a
89
    training graph by transforming the original inference graph. See
90
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
91
    transform graphs, and :func:`~blocks.graph.batch_normalization`
92
    for a context manager that may enable shorter compile times
93
    (every instance of :class:`BatchNormalization` is itself a context
94
    manager, entry into which causes applications to be in minibatch
95
    "training" mode, however it is usually more convenient to use
96
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
97
    for all of your graph's :class:`BatchNormalization` bricks at once).
98
99
    Note that training in inference mode should be avoided, as this
100
    brick introduces scales and shift parameters (tagged with the
101
    `PARAMETER` role) that, in the absence of batch normalization,
102
    usually makes things unstable. If you must do this, filter for and
103
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
104
    from the list of parameters you are training, and this brick should
105
    behave as a (somewhat expensive) no-op.
106
107
    This Brick accepts `scale_init` and `shift_init` arguments but is
108
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
109
    therefore not receive pushed initialization config from any parent
110
    brick. In almost all cases, you will probably want to stick with the
111
    defaults (unit scale and zero offset), but you can explicitly pass one
112
    or both initializers to override this.
113
114
    This has the necessary properties to be inserted into a
115
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
116
    the `input_dim` should be omitted at construction, to be inferred from
117
    the layer below.
118
119
120
    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
121
       accelerating deep network training by reducing internal covariate
122
       shift*. ICML (2015), pp. 448-456.
123
124
    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
125
       normalization: a simple reparameterization to accelerate training
126
       of deep neural networks*. arXiv 1602.07868.
127
128
    """
129
    @lazy(allocation=['input_dim'])
130
    def __init__(self, input_dim, broadcastable=None,
131
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
132
                 shift_init=None, mean_only=False, learn_scale=True,
133
                 learn_shift=True, **kwargs):
134
        self.input_dim = input_dim
135
        self.broadcastable = broadcastable
136
        self.conserve_memory = conserve_memory
137
        self.epsilon = epsilon
138
        self.learn_scale = learn_scale
139
        self.learn_shift = learn_shift
140
        self.scale_init = (Constant(1) if scale_init is None
141
                           else scale_init)
142
        self.shift_init = (Constant(0) if shift_init is None
143
                           else shift_init)
144
        self.mean_only = mean_only
145
        self._training_mode = []
146
        super(BatchNormalization, self).__init__(**kwargs)
147
148
    @application(inputs=['input_'], outputs=['output'])
149
    def apply(self, input_, application_call):
150
        if self._training_mode:
151
            mean, stdev = self._compute_training_statistics(input_)
152
        else:
153
            mean, stdev = self._prepare_population_statistics()
154
        # Useful for filtration of calls that were already made in
155
        # training mode when doing graph transformations.
156
        # Very important to cast to bool, as self._training_mode is
157
        # normally a list (to support nested context managers), which would
158
        # otherwise get passed by reference and be remotely mutated.
159
        application_call.metadata['training_mode'] = bool(self._training_mode)
160
        # Useful for retrieving a list of updates for population
161
        # statistics. Ditch the broadcastable first axis, though, to
162
        # make it the same dimensions as the population mean and stdev
163
        # shared variables.
164
        application_call.metadata['offset'] = mean[0]
165
        application_call.metadata['divisor'] = stdev[0]
166
        # Give these quantities roles in the graph.
167
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
168
                               [self, application_call])
169
        if self.mean_only:
170
            stdev = tensor.ones_like(mean)
171
        else:
172
            # The annotation/role information is useless if it's a constant.
173
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
174
                                   [self, application_call])
175
        shift = _add_batch_axis(self.shift)
176
        scale = _add_batch_axis(self.scale)
177
        # Heavy lifting is done by the Theano utility function.
178
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
179
                                            mode=('low_mem'
180
                                                  if self.conserve_memory
181
                                                  else 'high_mem'))
182
        return normalized
183
184
    def __enter__(self):
185
        self._training_mode.append(True)
186
187
    def __exit__(self, *exc_info):
188
        self._training_mode.pop()
189
190
    def _compute_training_statistics(self, input_, axes=None):
191
        if axes is None:
192
            axes = self.normalization_axes
193
        mean = input_.mean(axis=axes, keepdims=True)
194
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
195
        if self.mean_only:
196
            stdev = tensor.ones_like(mean)
197
        else:
198
            # We already have the mean; this saves Theano the trouble of
199
            # optimizing the graph produced by tensor.var().
200
            # This two pass version is going to be slightly more stable than
201
            # E[X^2] - E[X]^2, at least when n is small. It's also never
202
            # going to be negative due to numerical error.
203
            var = tensor.mean(tensor.sqr(input_ - mean),
204
                              axis=axes, keepdims=True)
205
            eps = numpy.cast[theano.config.floatX](self.epsilon)
206
            stdev = tensor.sqrt(var + eps)
207
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
208
        return mean, stdev
209
210
    @property
211
    def normalization_axes(self):
212
        return (0,) + tuple((i + 1) for i, b in
213
                            enumerate(self.population_mean.broadcastable)
214
                            if b)
215
216
    def _prepare_population_statistics(self):
217
        mean = _add_batch_axis(self.population_mean)
218
        if self.mean_only:
219
            stdev = tensor.ones_like(self.population_mean)
220
        else:
221
            stdev = self.population_stdev
222
        stdev = _add_batch_axis(stdev)
223
        return mean, stdev
224
225
    def _allocate(self):
226
        input_dim = ((self.input_dim,)
227
                     if not isinstance(self.input_dim, collections.Sequence)
228
                     else self.input_dim)
229
        broadcastable = (tuple(False for _ in input_dim)
230
                         if self.broadcastable is None else self.broadcastable)
231
        if len(input_dim) != len(broadcastable):
232
            raise ValueError("input_dim and broadcastable must be same length")
233
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
234
                        equizip(input_dim, broadcastable))
235
        broadcastable = broadcastable
236
237
        # "beta", from the Ioffe & Szegedy manuscript.
238
        if self.learn_shift:
239
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
240
                                            broadcastable=broadcastable)
241
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
242
            self.parameters.append(self.shift)
243
        else:
244
            self.shift = tensor.constant(0, dtype=theano.config.floatX)
245
246
        if self.learn_scale and not self.mean_only:
247
            # "gamma", from the Ioffe & Szegedy manuscript.
248
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
249
                                            broadcastable=broadcastable)
250
251
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
252
            self.parameters.append(self.scale)
253
        else:
254
            self.scale = tensor.constant(1., dtype=theano.config.floatX)
255
256
        self._allocate_population_statistics(var_dim, broadcastable)
257
258
    def _allocate_population_statistics(self, var_dim, broadcastable):
259
        def _allocate_buffer(name, role, value):
260
            # These aren't technically parameters, in that they should not be
261
            # learned using the same cost function as other model parameters.
262
            population_buffer = shared_floatx(value * numpy.ones(var_dim),
263
                                              broadcastable=broadcastable,
264
                                              name=name)
265
            add_role(population_buffer, role)
266
            # Normally these would get annotated by an AnnotatingList, but they
267
            # aren't in self.parameters.
268
            add_annotation(population_buffer, self)
269
            return population_buffer
270
271
        self.population_mean = _allocate_buffer('population_mean',
272
                                                BATCH_NORM_POPULATION_MEAN,
273
                                                numpy.zeros(var_dim))
274
275
        self.population_stdev = _allocate_buffer('population_stdev',
276
                                                 BATCH_NORM_POPULATION_STDEV,
277
                                                 numpy.ones(var_dim))
278
279
    def _initialize(self):
280
        # We gate with is_shared_variable rather than relying on
281
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
282
        # scenario where those flags are changed post-allocation but
283
        # pre-initialization. This ensures that such a change simply has no
284
        # effect rather than doing an inconsistent combination of things.
285
        if is_shared_variable(self.shift):
286
            self.shift_init.initialize(self.shift, self.rng)
287
        if is_shared_variable(self.scale):
288
            self.scale_init.initialize(self.scale, self.rng)
289
290
    # Needed for the Feedforward interface.
291
    @property
292
    def output_dim(self):
293
        return self.input_dim
294
295
    # The following properties allow for BatchNormalization bricks
296
    # to be used directly inside of a ConvolutionalSequence.
297
    @property
298
    def image_size(self):
299
        return self.input_dim[-2:]
300
301
    @image_size.setter
302
    def image_size(self, value):
303
        if not isinstance(self.input_dim, collections.Sequence):
304
            self.input_dim = (None,) + tuple(value)
305
        else:
306
            self.input_dim = (self.input_dim[0],) + tuple(value)
307
308
    @property
309
    def num_channels(self):
310
        return self.input_dim[0]
311
312
    @num_channels.setter
313
    def num_channels(self, value):
314
        if not isinstance(self.input_dim, collections.Sequence):
315
            self.input_dim = (value,) + (None, None)
316
        else:
317
            self.input_dim = (value,) + self.input_dim[-2:]
318
319
    def get_dim(self, name):
320
        if name in ('input', 'output'):
321
            return self.input_dim
322
        else:
323
            raise KeyError
324
325
    @property
326
    def num_output_channels(self):
327
        return self.num_channels
328
329
330
class SpatialBatchNormalization(BatchNormalization):
331
    """Convenient subclass for batch normalization across spatial inputs.
332
333
    Parameters
334
    ----------
335
    input_dim : int or tuple
336
        The input size of a single example. Must be length at least 2.
337
        It's assumed that the first axis of this tuple is a "channels"
338
        axis, which should not be summed over, and all remaining
339
        dimensions are spatial dimensions.
340
341
    Notes
342
    -----
343
    See :class:`BatchNormalization` for more details (and additional
344
    keyword arguments).
345
346
    """
347
    def _allocate(self):
348
        if not isinstance(self.input_dim,
349
                          collections.Sequence) or len(self.input_dim) < 2:
350
            raise ValueError('expected input_dim to be length >= 2 '
351
                             'e.g. (channels, height, width)')
352
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
353
        super(SpatialBatchNormalization, self)._allocate()
354
355
356
class BatchNormalizedMLP(MLP):
357
    """Convenient subclass for building an MLP with batch normalization.
358
359
    Parameters
360
    ----------
361
    conserve_memory : bool, optional, by keyword only
362
        See :class:`BatchNormalization`.
363
    mean_only : bool, optional, by keyword only
364
        See :class:`BatchNormalization`.
365
    learn_scale : bool, optional, by keyword only
366
        See :class:`BatchNormalization`.
367
    learn_shift : bool, optional, by keyword only
368
        See :class:`BatchNormalization`.
369
370
    Notes
371
    -----
372
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
373
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
374
    containing an appropriate :class:`BatchNormalization` brick and
375
    the activation that follows it.
376
377
    By default, the contained :class:`~blocks.bricks.Linear` bricks will
378
    not contain any biases, as they could be canceled out by the biases
379
    in the :class:`BatchNormalization` bricks being added. Pass
380
    `use_bias` with a value of `True` if you really want this for some
381
    reason.
382
383
    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
384
    all created :class:`BatchNormalization` bricks as allocation
385
    config.
386
387
    """
388
    @lazy(allocation=['dims'])
389
    def __init__(self, activations, dims, *args, **kwargs):
390
        self._conserve_memory = kwargs.pop('conserve_memory', True)
391
        self.mean_only = kwargs.pop('mean_only', False)
392
        self.learn_scale = kwargs.pop('learn_scale', True)
393
        self.learn_shift = kwargs.pop('learn_shift', True)
394
395
        activations = [
396
            Sequence([
397
                (BatchNormalization(conserve_memory=self._conserve_memory)
398
                 .apply),
399
                act.apply
400
            ], name='batch_norm_activation_{}'.format(i))
401
            for i, act in enumerate(activations)
402
        ]
403
        # Batch normalization bricks incorporate a bias, so there's no
404
        # need for our Linear bricks to have them.
405
        kwargs.setdefault('use_bias', False)
406
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
407
                                                 **kwargs)
408
409
    def _nested_brick_property_getter(self, property_name):
410
        return getattr(self, '_' + property_name)
411
412
    def _nested_brick_property_setter(self, value, property_name):
413
        setattr(self, '_' + property_name, value)
414
        for act in self.activations:
415
            assert isinstance(act.children[0], BatchNormalization)
416
            setattr(act.children[0], property_name, value)
417
418
    # conserve_memory is a bit special in that it can be modified
419
    # after construction/allocation and still have a valid effect on
420
    # apply(). Thus we propagate down all property sets.
421
    conserve_memory = property(partial(_nested_brick_property_getter,
422
                                       property_name='conserve_memory'),
423
                               partial(_nested_brick_property_setter,
424
                                       property_name='conserve_memory'),
425
                               doc="Conserve memory.")
426
427
    def _push_allocation_config(self):
428
        super(BatchNormalizedMLP, self)._push_allocation_config()
429
        # Do the extra allocation pushing for the BatchNormalization
430
        # bricks. They need as their input dimension the output dimension
431
        # of each linear transformation.  Exclude the first dimension,
432
        # which is the input dimension.
433
        for act, dim in equizip(self.activations, self.dims[1:]):
434
            assert isinstance(act.children[0], BatchNormalization)
435
            act.children[0].input_dim = dim
436
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
437
                setattr(act.children[0], attr, getattr(self, attr))
438