Completed
Pull Request — master (#941)
Created by David at 01:45

blocks.bricks.SpatialBatchNormalization   A

Complexity

Total Complexity 3

Size/Duplication

Total Lines 26
Duplicated Lines 0 %

Metric   Value
dl       0
loc      26
rs       10
wmc      3

1 Method

Rating   Name         Duplication   Size   Complexity
A        __init__()   0             9      3

import collections

import numpy
from picklable_itertools.extras import equizip
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


class BatchNormalization(RNGMixin, Feedforward):
    """Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`batch_normalize`.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero shift), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        mean = self.population_mean.copy(name='population_offset')
        stdev = self.population_stdev.copy(name='population_divisor')

        def annotate(var, role):
            add_role(var, role)
            add_annotation(var, self)
            add_annotation(var, application_call)

        annotate(mean, BATCH_NORM_OFFSET)
        annotate(stdev, BATCH_NORM_DIVISOR)

        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, self.W,
                                            self.b, mean, stdev,
                                            mode=('low_mem' if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = ((1,) +  # batch axis
                   tuple(1 if broadcast else dim for dim, broadcast in
                         equizip(input_dim, broadcastable)))
        broadcastable = (True,) + broadcastable

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels
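

# Illustrative usage sketch (not part of this file): how the brick above
# might be applied on its own. The import path follows the report header
# (blocks.bricks); the variable names and the input dimension of 100 are
# assumptions for the example, not taken from this pull request.
#
#     import theano.tensor as tensor
#     from blocks.bricks import BatchNormalization
#
#     x = tensor.matrix('features')
#     brick = BatchNormalization(input_dim=100)
#     brick.initialize()  # scale ("gamma") <- 1, shift ("beta") <- 0
#     y = brick.apply(x)  # inference mode: normalizes with the population
#                         # mean/stdev, then applies W and b

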
class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)
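

# Illustrative sketch (not part of this file) of the default behaviour of
# the subclass above: for a (channels, height, width) input such as
# (16, 32, 32), `broadcastable` defaults to (False, True, True), so the
# statistics and parameters are shared across all spatial locations but
# kept separate per channel. The import path follows the report header;
# the shape is an assumption for the example.
#
#     from blocks.bricks import SpatialBatchNormalization
#
#     brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
#     brick.allocate()
#     assert brick.population_mean.get_value().shape == (1, 16, 1, 1)

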
class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Notes
    -----
    All parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as these would be removed by the mean
    subtraction in the :class:`BatchNormalization` bricks that follow,
    whose shift parameters serve the same purpose. Pass `use_bias=True`
    if you really want the Linear biases anyway.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        activations = [Sequence([BatchNormalization().apply, act.apply],
                                name='batch_norm_activation_{}'.format(i))
                       for i, act in enumerate(activations)]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            act.children[0].input_dim = dim
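
A rough usage sketch for BatchNormalizedMLP, assuming the bricks are importable from blocks.bricks as the report header suggests; the activations, layer sizes, and initializer below are illustrative choices, not taken from this pull request. Each hidden layer becomes a bias-free Linear transform followed by a BatchNormalization brick and the activation, with each BatchNormalization brick's input_dim filled in from the corresponding entry of dims by _push_allocation_config.

import theano.tensor as tensor
from blocks.bricks import BatchNormalizedMLP, Rectifier
from blocks.initialization import IsotropicGaussian

x = tensor.matrix('features')
mlp = BatchNormalizedMLP(activations=[Rectifier(), Rectifier()],
                         dims=[784, 200, 100],
                         weights_init=IsotropicGaussian(0.01))
mlp.initialize()
# This is the inference-mode graph; as the class notes above explain, a
# training graph is obtained by transforming it (see batch_normalize).
y = mlp.apply(x)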