Completed
Pull Request — master (#941) by David, created 01:52

blocks.bricks.BatchNormalization.W() — rated A

Complexity:   Conditions 1
Size:         Total Lines 3
Duplication:  Lines 0, Ratio 0 %

Metric  Value
cc      1
dl      0
loc     3
rs      10
import collections

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var, name=None):
    """Prepend a singleton axis to a TensorVariable."""
    new_var = var.dimshuffle('x', *list(range(var.ndim)))
    new_var.name = name
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make training unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT` and `BATCH_NORM_SCALE` from the list of
    parameters you are training, and this brick should behave as a
    (somewhat expensive) no-op.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.
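
    Examples
    --------
    A minimal usage sketch (the variable names below are illustrative
    and not part of the brick's API):

    >>> from theano import tensor
    >>> brick = BatchNormalization(input_dim=10)
    >>> brick.initialize()
    >>> x = tensor.matrix('x')
    >>> y_inference = brick.apply(x)  # uses the population statistics
    >>> with brick:  # entering the brick enables minibatch "training" mode
    ...     y_training = brick.apply(x)  # uses minibatch statistics

    As noted above, :func:`~blocks.graph.batch_normalization` is usually
    a more convenient way to put every such brick in a graph into
    training mode at once.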
    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, epsilon=1e-4, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        self.epsilon = epsilon
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        self._training_mode = False
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        application_call.metadata['training_mode'] = self._training_mode
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                               [self, application_call])
        W = _add_batch_axis(self.W, "W.dimshuffle('x', ...)")
        b = _add_batch_axis(self.b, "b.dimshuffle('x', ...)")
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, W, b, mean, stdev,
                                            mode=('low_mem' if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode = True

    def __exit__(self, *exc_info):
        self._training_mode = False

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                            numpy.cast[theano.config.floatX](self.epsilon))
        assert stdev.broadcastable[1:] == self.population_stdev.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean, 'population_offset')
        stdev = _add_batch_axis(self.population_stdev, 'population_divisor')
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).
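
    Examples
    --------
    A brief construction sketch (the `(channels, height, width)` shape
    below is purely illustrative):

    >>> brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
    >>> brick.initialize()

    This is equivalent to :class:`BatchNormalization` with
    `input_dim=(16, 32, 32)` and `broadcastable=(False, True, True)`:
    statistics are aggregated over the batch and both spatial axes, with
    one scale and shift parameter per channel.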
    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    save_memory : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they would be redundant with the shift
    parameters added by the :class:`BatchNormalization` bricks. Pass
    `use_bias` with a value of `True` if you really want them for some
    reason.
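
    Examples
    --------
    A minimal construction sketch (the activations, dimensions and
    initialization scheme below are illustrative, not prescribed by
    this brick):

    >>> from blocks.bricks import Rectifier, Identity
    >>> from blocks.initialization import IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Rectifier(), Identity()], [30, 20, 10],
    ...                          weights_init=IsotropicGaussian(0.01))
    >>> mlp.initialize()

    Each layer is then a bias-free :class:`~blocks.bricks.Linear`
    transformation followed by batch normalization and the corresponding
    activation.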
    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        save_memory = kwargs.pop('save_memory', True)
        # Record the setting so the `save_memory` property is readable
        # before it is ever explicitly assigned.
        self._save_memory = save_memory
        activations = [
            Sequence([BatchNormalization(save_memory=save_memory).apply,
                      act.apply], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    @property
    def save_memory(self):
        return self._save_memory

    @save_memory.setter
    def save_memory(self, value):
        self._save_memory = value
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].save_memory = value

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation.  Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            act.children[0].input_dim = dim
346