Completed
Push — master (568e7a...47285e) by Vincent
25s
created

_nested_brick_property_setter()    Grade: A

Complexity
    Conditions    3
Size
    Total Lines   5
Duplication
    Lines         0
    Ratio         0 %

Metric    Value
cc        3
dl        0
loc       5
rs        9.4285
import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times.
    (Every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which puts its applications in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of a graph's :class:`BatchNormalization` bricks at once.)

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This brick has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            scale = tensor.ones_like(self.shift)
            stdev = tensor.ones_like(mean)
        else:
            scale = self.scale
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized
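
    # Sketch (an assumption, not part of this brick): the 'offset' and
    # 'divisor' metadata recorded above have the same shape as the
    # population shared variables, so a training loop could use them for
    # an exponential moving average with some decay rate `alpha`:
    #
    #     updates = [
    #         (self.population_mean,
    #          (1 - alpha) * self.population_mean + alpha * mean_batch),
    #         (self.population_stdev,
    #          (1 - alpha) * self.population_stdev + alpha * stdev_batch)]
    #
    # where `mean_batch` and `stdev_batch` are the metadata entries pulled
    # out of a training graph.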

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                                numpy.cast[theano.config.floatX](self.epsilon))
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev
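
    # For example (an assumed configuration): if population_mean has shape
    # (channels, 1, 1), i.e. broadcastable=(False, True, True), the axes
    # computed above are (0, 2, 3): the batch axis plus every broadcastable
    # per-example axis.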

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.shift)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)

            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)

            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            add_annotation(self.population_stdev, self)
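
    # For example (assumed sizes): input_dim=(16, 32, 32) with
    # broadcastable=(False, True, True) gives var_dim=(16, 1, 1), so shift,
    # scale and the population statistics hold one value per channel and
    # broadcast over the spatial axes; shift and scale remain NaN until
    # _initialize() runs.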

    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        if not self.mean_only:
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must be of length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.
    mean_only : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self._mean_only = kwargs.pop('mean_only', False)
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=self._conserve_memory,
                                   mean_only=self._mean_only).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # Expose conserve_memory and mean_only as properties that forward any
    # assignment to the BatchNormalization bricks nested in the activations.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'))

    mean_only = property(partial(_nested_brick_property_getter,
                                 property_name='mean_only'),
                         partial(_nested_brick_property_setter,
                                 property_name='mean_only'))

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
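
    # Continuing the construction sketch above (dims=[10, 50, 50]): the two
    # nested BatchNormalization bricks receive input_dim 50 and 50, the
    # output dimensions of the Linear bricks that precede them.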