Completed
Push to master (67f35d...82886a) by David
created 04:44

BatchNormalization._allocate() (rating: F)

Complexity: Conditions 11
Size: Total Lines 48
Duplication: Lines 0, Ratio 0 %
Importance: Changes 1, Bugs 0, Features 0

Metric   Value
cc       11
c        1
b        0
f        0
dl       0
loc      48
rs       3.1764

How to fix: Complexity

Complex methods like BatchNormalization._allocate() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it (or within its class). A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields and methods belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
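For concreteness, here is one way that advice could play out on the flagged method. The allocation logic in _allocate() already groups naturally by prefix: the learned shift/scale parameters on one side and the population_* statistics on the other. The sketch below regroups that logic behind two helper methods so that _allocate() itself carries fewer conditions. It is a hypothetical illustration only, assuming the same imports and BatchNormalization class context as the listing that follows; the helper names _allocate_parameter and _allocate_population_statistic are invented here and are not part of Blocks.

    # Sketch only (not part of Blocks): a hypothetical regrouping of the
    # allocation logic. Behaviour is unchanged; the branches simply move
    # into two helpers, one per prefix group.

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta" and "gamma", from the Ioffe & Szegedy manuscript.
        self.shift = self._allocate_parameter(
            self.learn_shift, var_dim, broadcastable,
            name='batch_norm_shift', role=BATCH_NORM_SHIFT_PARAMETER,
            constant_value=0)
        self.scale = self._allocate_parameter(
            self.learn_scale and not self.mean_only, var_dim, broadcastable,
            name='batch_norm_scale', role=BATCH_NORM_SCALE_PARAMETER,
            constant_value=1.)

        # Population statistics are shared variables, but not parameters
        # trained against the model's cost function.
        self.population_mean = self._allocate_population_statistic(
            shared_floatx_zeros(var_dim, name='population_mean',
                                broadcastable=broadcastable),
            BATCH_NORM_POPULATION_MEAN)
        if not self.mean_only:
            self.population_stdev = self._allocate_population_statistic(
                shared_floatx(numpy.ones(var_dim), name='population_stdev',
                              broadcastable=broadcastable),
                BATCH_NORM_POPULATION_STDEV)

    def _allocate_parameter(self, learned, var_dim, broadcastable, name, role,
                            constant_value):
        """Return a learned shared parameter, or a constant stand-in."""
        if not learned:
            return tensor.constant(constant_value, dtype=theano.config.floatX)
        parameter = shared_floatx_nans(var_dim, name=name,
                                       broadcastable=broadcastable)
        add_role(parameter, role)
        self.parameters.append(parameter)
        return parameter

    def _allocate_population_statistic(self, variable, role):
        """Tag and annotate a population-statistic shared variable."""
        add_role(variable, role)
        # Annotated by hand because these variables are not in self.parameters.
        add_annotation(variable, self)
        return variable

Whether the grouping ends up as helper methods (Extract Method) or as a separate collaborator class (Extract Class) is a judgement call; either way, each piece then owns exactly one of the prefix groups described above.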

import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans, is_shared_variable)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.
    learn_scale : bool, optional
        Whether to include a learned scale parameter ($\\gamma$ in [BN]_)
        in this brick. Default is `True`. Has no effect if `mean_only` is
        `True` (i.e. a scale parameter is never learned in mean-only mode).
    learn_shift : bool, optional
        Whether to include a learned shift parameter ($\\beta$ in [BN]_)
        in this brick. Default is `True`.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, learn_scale=True,
                 learn_shift=True, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.learn_scale = learn_scale
        self.learn_shift = learn_shift
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(self.scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
                   tensor.sqr(mean))
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        if self.learn_shift:
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                            broadcastable=broadcastable)
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
            self.parameters.append(self.shift)
        else:
            self.shift = tensor.constant(0, dtype=theano.config.floatX)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if self.learn_scale and not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)
        else:
            self.scale = tensor.constant(1., dtype=theano.config.floatX)

        if not self.mean_only:
            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            add_annotation(self.population_stdev, self)

    def _initialize(self):
        # We gate with is_shared_variable rather than relying on
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
        # scenario where those flags are changed post-allocation but
        # pre-initialization. This ensures that such a change simply has no
        # effect rather than doing an inconsistent combination of things.
        if is_shared_variable(self.shift):
            self.shift_init.initialize(self.shift, self.rng)
        if is_shared_variable(self.scale):
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    mean_only : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_scale : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_shift : bool, optional, by keyword only
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
    all created :class:`BatchNormalization` bricks as allocation
    config.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self.mean_only = kwargs.pop('mean_only', False)
        self.learn_scale = kwargs.pop('learn_scale', True)
        self.learn_shift = kwargs.pop('learn_shift', True)

        activations = [
            Sequence([
                (BatchNormalization(conserve_memory=self._conserve_memory)
                 .apply),
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # conserve_memory is a bit special in that it can be modified
    # after construction/allocation and still have a valid effect on
    # apply(). Thus we propagate down all property sets.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'))

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
                setattr(act.children[0], attr, getattr(self, attr))