Completed: Push to master (76676f...e3a1ef), by unknown, created 03:29

BatchNormalization (rating B)

Complexity

Total Complexity 42

Size/Duplication

Total Lines 286
Duplicated Lines 0 %

Importance

Changes 7
Bugs 1
Features 0

Metric                   Value
c (Changes)              7
b (Bugs)                 1
f (Features)             0
dl (Duplicated Lines)    0
loc (Total Lines)        286
rs                       8.295
wmc (Total Complexity)   42

16 Methods

Rating   Name   Duplication   Size   Complexity  
B apply() 0 35 4
A __exit__() 0 2 1
A __init__() 0 18 3
A __enter__() 0 2 1
A image_size() 0 3 2
A _compute_training_statistics() 0 14 3
F _allocate() 0 32 10
A normalization_axes() 0 5 3
A get_dim() 0 5 2
A _allocate_population_statistics() 0 20 2
A num_channels() 0 3 2
A num_output_channels() 0 3 1
A _prepare_population_statistics() 0 8 2
A _initialize() 0 10 3
A output_dim() 0 3 1
A _allocate_buffer() 0 11 1

How to fix: Complexity

Complex Class

Complex classes like BatchNormalization often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
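
As a purely illustrative sketch (the `PopulationStatistics` helper below is hypothetical, not part of this module, and it reuses the module's own imports), the population-statistics buffers allocated in `_allocate_population_statistics` are one such cohesive component that could be extracted:

class PopulationStatistics(object):
    """Owns the population mean/stdev buffers and their graph roles."""
    def __init__(self, var_dim, broadcastable, owner):
        self.mean = shared_floatx(numpy.zeros(var_dim),
                                  name='population_mean',
                                  broadcastable=broadcastable)
        self.stdev = shared_floatx(numpy.ones(var_dim),
                                   name='population_stdev',
                                   broadcastable=broadcastable)
        for var, role in ((self.mean, BATCH_NORM_POPULATION_MEAN),
                          (self.stdev, BATCH_NORM_POPULATION_STDEV)):
            add_role(var, role)
            add_annotation(var, owner)

# BatchNormalization._allocate would then delegate to the helper:
#     self.population = PopulationStatistics(var_dim, broadcastable, self)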

import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.
    learn_scale : bool, optional
        Whether to include a learned scale parameter ($\\gamma$ in [BN]_)
        in this brick. Default is `True`. Has no effect if `mean_only` is
        `True` (i.e. a scale parameter is never learned in mean-only mode).
    learn_shift : bool, optional
        Whether to include a learned shift parameter ($\\beta$ in [BN]_)
        in this brick. Default is `True`.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, learn_scale=True,
                 learn_shift=True, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.learn_scale = learn_scale
        self.learn_shift = learn_shift
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(self.scale)
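        # Theano's bn.batch_normalization computes, elementwise,
        #     output = (input_ - mean) * (scale / stdev) + shift.
        # The mode below only changes how much intermediate state is kept.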
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_, axes=None):
        if axes is None:
            axes = self.normalization_axes
        mean = input_.mean(axis=axes, keepdims=True)
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
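            # Minibatch variance via E[x**2] - (E[x])**2; epsilon is added
            # under the square root for numerical stability.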
            var = (tensor.sqr(input_).mean(axis=axes, keepdims=True) -
                   tensor.sqr(mean))
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

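    # Axes averaged over when computing minibatch statistics: the batch axis
    # plus every per-example axis marked broadcastable at allocation time.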
    @property
    def normalization_axes(self):
        return (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
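        # Collapse each broadcast axis to size 1 so that the shift/scale
        # parameters and the population buffers broadcast over those axes.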
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        if self.learn_shift:
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                            broadcastable=broadcastable)
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
            self.parameters.append(self.shift)
        else:
            self.shift = tensor.constant(0, dtype=theano.config.floatX)

        if self.learn_scale and not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)

            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)
        else:
            self.scale = tensor.constant(1., dtype=theano.config.floatX)

        self._allocate_population_statistics(var_dim, broadcastable)

    def _allocate_population_statistics(self, var_dim, broadcastable):
        def _allocate_buffer(name, role, value):
            # These aren't technically parameters, in that they should not be
            # learned using the same cost function as other model parameters.
            population_buffer = shared_floatx(value * numpy.ones(var_dim),
                                              broadcastable=broadcastable,
                                              name=name)
            add_role(population_buffer, role)
            # Normally these would get annotated by an AnnotatingList, but they
            # aren't in self.parameters.
            add_annotation(population_buffer, self)
            return population_buffer

        self.population_mean = _allocate_buffer('population_mean',
                                                BATCH_NORM_POPULATION_MEAN,
                                                numpy.zeros(var_dim))

        self.population_stdev = _allocate_buffer('population_stdev',
                                                 BATCH_NORM_POPULATION_STDEV,
                                                 numpy.ones(var_dim))

    def _initialize(self):
        # We gate with is_shared_variable rather than relying on
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
        # scenario where those flags are changed post-allocation but
        # pre-initialization. This ensures that such a change simply has no
        # effect rather than doing an inconsistent combination of things.
        if is_shared_variable(self.shift):
            self.shift_init.initialize(self.shift, self.rng)
        if is_shared_variable(self.scale):
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    mean_only : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_scale : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_shift : bool, optional, by keyword only
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
    all created :class:`BatchNormalization` bricks as allocation
    config.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self.mean_only = kwargs.pop('mean_only', False)
        self.learn_scale = kwargs.pop('learn_scale', True)
        self.learn_shift = kwargs.pop('learn_shift', True)

        activations = [
            Sequence([
                (BatchNormalization(conserve_memory=self._conserve_memory)
                 .apply),
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # conserve_memory is a bit special in that it can be modified
    # after construction/allocation and still have a valid effect on
    # apply(). Thus we propagate down all property sets.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'),
                               doc="Conserve memory.")

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
                setattr(act.children[0], attr, getattr(self, attr))
433