Pull Request — master (#941)
by David

blocks.bricks.SpatialBatchNormalization — rating A

Complexity

Total Complexity    3

Size/Duplication

Total Lines         26
Duplicated Lines    0 %

Metric    Value
dl        0
loc       26
rs        10
wmc       3

1 Method

Rating    Name          Duplication    Size    Complexity
A         __init__()    0              9       3
import collections

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var, name=None):
    """Prepend a singleton axis to a TensorVariable."""
    new_var = var.dimshuffle('x', *list(range(var.ndim)))
    new_var.name = name
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT` and `BATCH_NORM_SCALE` from the list of
    parameters you are training, and this brick should behave as a
    (somewhat expensive) no-op.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.
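
    Examples
    --------
    A minimal usage sketch (illustrative only; the input dimension and
    variable names below are placeholders, not part of this brick's API):

    >>> from theano import tensor
    >>> x = tensor.matrix('features')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> y = brick.apply(x)  # inference mode: population statistics
    >>> with brick:  # minibatch "training" mode inside this block
    ...     y_train = brick.apply(x)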

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        self._training_mode = False
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        application_call.metadata['training_mode'] = self._training_mode
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                               [self, application_call])
        W = _add_batch_axis(self.W, "W.dimshuffle('x', ...)")
        b = _add_batch_axis(self.b, "b.dimshuffle('x', ...)")
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, W, b, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode = True

    def __exit__(self, *exc_info):
        self._training_mode = False

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                            numpy.cast[theano.config.floatX](self.epsilon))
        assert stdev.broadcastable[1:] == self.population_stdev.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean, 'population_offset')
        stdev = _add_batch_axis(self.population_stdev, 'population_divisor')
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).
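
    Examples
    --------
    A brief sketch (the shape below is an arbitrary assumption for
    illustration): for a `(channels, height, width)` input, this brick is
    equivalent to :class:`BatchNormalization` constructed with
    `broadcastable=(False, True, True)`.

    >>> brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
    >>> brick.broadcastable
    (False, True, True)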

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.
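
    Examples
    --------
    An illustrative sketch; the activation choice, dimensions and
    initialization scheme here are arbitrary assumptions, not defaults:

    >>> from blocks.bricks import Rectifier
    >>> from blocks.initialization import IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Rectifier(), Rectifier()], [784, 200, 100],
    ...                          weights_init=IsotropicGaussian(0.01))
    >>> mlp.initialize()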

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        conserve_memory = kwargs.pop('conserve_memory', True)
        activations = [
            Sequence([BatchNormalization(conserve_memory=conserve_memory)
                      .apply, act.apply],
                     name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)
        # Record the setting so that the conserve_memory property below
        # is initialized (and propagated to the child bricks).
        self.conserve_memory = conserve_memory

    @property
    def conserve_memory(self):
        return self._conserve_memory

    @conserve_memory.setter
    def conserve_memory(self, value):
        self._conserve_memory = value
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].conserve_memory = value

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim