Completed
Pull Request — master (#941)
by David
23:17 queued 14s
created

blocks.bricks.BatchNormalization._allocate()   C

Complexity

Conditions 7

Size

Total Lines 39

Duplication

Lines 0
Ratio 0 %
Metric   Value
cc       7
dl       0
loc      39
rs       5.5
import collections

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, this brick runs in *inference* mode by default,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times.
    (Every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes its applications to run in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of a graph's :class:`BatchNormalization` bricks at once.)

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make training unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This brick has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                               [self, application_call])
        scale = _add_batch_axis(self.scale)
        shift = _add_batch_axis(self.shift)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                            numpy.cast[theano.config.floatX](self.epsilon))
        assert stdev.broadcastable[1:] == self.population_stdev.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev
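    # Worked example of the axes computation above (assumed spatial values,
    # not from the reviewed code): if population_mean.broadcastable is
    # (False, True, True) for a (channels, height, width) example, then
    #     broadcastable = (False, True, True)
    #     axes = (0,) + tuple(i + 1 for i, b in enumerate(broadcastable) if b)
    #     # axes == (0, 2, 3)
    # so minibatch statistics are taken over the batch and spatial axes,
    # leaving one mean/stdev per channel.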
    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        stdev = _add_batch_axis(self.population_stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "gamma", from the Ioffe & Szegedy manuscript.
        self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                        broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.scale)
        self.parameters.append(self.shift)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)
        add_annotation(self.population_stdev, self)
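    # Worked example for the allocation above (assumed values, not from the
    # reviewed code): for a plain feed-forward layer with input_dim == 100,
    # input_dim is wrapped to (100,), broadcastable defaults to (False,),
    # and var_dim == (100,), i.e. one scale/shift/mean/stdev per feature.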
    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels
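The inference/training switch described in the class Notes can be exercised directly. A minimal sketch, assuming the standard blocks Brick API (e.g. `initialize()`), a Theano input variable, and an arbitrary `input_dim` of 100; none of this is part of the reviewed code:

    x = tensor.matrix('features')
    brick = BatchNormalization(input_dim=100)
    brick.initialize()

    # Inference graph (the default): population_mean/population_stdev
    # normalize the input.
    y_inference = brick.apply(x)

    # Training graph: entering the brick (or using
    # blocks.graph.batch_normalization for a whole graph) makes apply()
    # compute minibatch statistics instead.
    with brick:
        y_training = brick.apply(x)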
class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)

    def _allocate(self):
        # The analysis flagged an undefined variable here ("Undefined
        # variable 'input_dim'"); the length check refers to self.input_dim.
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        super(SpatialBatchNormalization, self)._allocate()
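As a quick check of the default constructed in `__init__` above, a sketch with assumed dimensions (`allocate()` and `get_value()` are the standard Brick and Theano shared-variable APIs; this is not part of the reviewed code):

    sbn = SpatialBatchNormalization(input_dim=(16, 32, 32))
    # broadcastable defaults to (False,) + (True,) * 2 == (False, True, True)
    sbn.allocate()
    # One mean/stdev per channel, broadcast over the spatial axes.
    assert sbn.population_mean.get_value().shape == (16, 1, 1)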
class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        conserve_memory = kwargs.pop('conserve_memory', True)
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=conserve_memory).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    @property
    def conserve_memory(self):
        return self._conserve_memory

    @conserve_memory.setter
    def conserve_memory(self, value):
        self._conserve_memory = value
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].conserve_memory = value

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
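To illustrate how these dimensions are pushed, a sketch assuming blocks' `Rectifier` activation brick, the public `push_allocation_config()` method, and arbitrary layer sizes (neither the brick nor the sizes appear in the reviewed code):

    mlp = BatchNormalizedMLP([Rectifier(), Rectifier()],
                             dims=[784, 200, 100])
    mlp.push_allocation_config()
    # Each activation is a Sequence whose first child is a
    # BatchNormalization brick; its input_dim is the output dimension of
    # the preceding linear layer, i.e. the matching entry of dims[1:].
    assert mlp.activations[0].children[0].input_dim == 200
    assert mlp.activations[1].children[0].input_dim == 100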