Completed
Pull Request — master (#941)
Created by David at 01:46

blocks.bricks.BatchNormalization.add_batch_axis()   A

Complexity
    Conditions: 1
Size
    Total Lines: 4
Duplication
    Lines: 0
    Ratio: 0 %

Metric   Value
cc       1
dl       0
loc      4
rs       10
import collections

import numpy
from picklable_itertools.extras import equizip
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph.  See
    :func:`batch_normalize`.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero shift), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
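        # The population statistics and learned parameters are stored
        # without a batch axis; the helper below prepends a broadcastable
        # batch axis (via dimshuffle) so that they broadcast against the
        # (batch, ...) input.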
        def add_batch_axis(var, name=None):
            new_var = var.dimshuffle('x', *list(range(var.ndim)))
            new_var.name = name
            return new_var

        def annotate(var, role):
            add_role(var, role)
            add_annotation(var, self)
            add_annotation(var, application_call)

        mean = add_batch_axis(self.population_mean, 'population_offset')
        annotate(mean, BATCH_NORM_OFFSET)

        stdev = add_batch_axis(self.population_stdev, 'population_divisor')
        annotate(stdev, BATCH_NORM_DIVISOR)
        W = add_batch_axis(self.W)
        b = add_batch_axis(self.b)
        # Heavy lifting is done by the Theano utility function.
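        # bn.batch_normalization computes (input_ - mean) * (W / stdev) + b,
        # i.e. the normalization followed by the learned scale and shift.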
        normalized = bn.batch_normalization(input_, W, b, mean, stdev,
                                            mode=('low_mem' if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
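        # Average statistics over all spatial axes, keeping separate
        # per-channel statistics (the first axis is assumed to be channels).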
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    save_memory : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        save_memory = kwargs.pop('save_memory', True)
        # Record the flag so that the `save_memory` property below works
        # even if its setter is never invoked.
        self._save_memory = save_memory
        activations = [
            Sequence([BatchNormalization(save_memory=save_memory).apply,
                      act.apply], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    @property
    def save_memory(self):
        return self._save_memory

    @save_memory.setter
    def save_memory(self, value):
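        # Propagate the flag to every child BatchNormalization brick.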
        self._save_memory = value
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].save_memory = value

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation.  Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            act.children[0].input_dim = dim
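
For context, a rough usage sketch of the bricks introduced in this PR follows. It is not part of the diff: the import path, layer sizes, and input shapes are illustrative assumptions, and it leans on the existing Blocks API (Rectifier, IsotropicGaussian, apply/initialize) together with the default `save_memory=True` shown above.

    import numpy
    import theano
    from theano import tensor

    from blocks.bricks import Rectifier
    from blocks.initialization import IsotropicGaussian
    # Assumed import path for the new bricks once this PR is merged.
    from blocks.bricks import BatchNormalizedMLP, SpatialBatchNormalization

    # Fully-connected case: every Linear output is batch-normalized
    # (inference mode by default) before its activation is applied.
    x = tensor.matrix('features')
    mlp = BatchNormalizedMLP([Rectifier(), Rectifier()], [100, 50, 10],
                             weights_init=IsotropicGaussian(0.01))
    y = mlp.apply(x)
    mlp.initialize()

    # Convolutional case: per-channel statistics, averaged over the
    # spatial axes of a (channels, height, width) input.
    images = tensor.tensor4('images')
    spatial_bn = SpatialBatchNormalization(input_dim=(16, 32, 32))
    normalized = spatial_bn.apply(images)
    spatial_bn.initialize()

    f = theano.function([x, images], [y, normalized])
    out, norm = f(
        numpy.random.rand(8, 100).astype(theano.config.floatX),
        numpy.random.rand(8, 16, 32, 32).astype(theano.config.floatX))

As the class docstring notes, this builds the inference graph only; training with minibatch statistics requires transforming that graph (see :func:`batch_normalize`), which is outside the scope of this sketch.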