Completed
Pull Request — master (#941)
by David
01:49
created

blocks.bricks.BatchNormalization.apply()   A

Complexity

Conditions 3

Size

Total Lines 19

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 19
rs 9.4285

1 Method

Rating   Name   Duplication   Size   Complexity  
A blocks.bricks.BatchNormalization.annotate() 0 4 1
1
import collections
2
3
import numpy
4
from picklable_itertools.extras import equizip
5
from theano import tensor
0 ignored issues
show
Unused Code introduced by
Unused tensor imported from theano
Loading history...
6
from theano.tensor.nnet import bn
7
8
from ..graph import add_annotation
9
from ..initialization import Constant
10
from ..filter import VariableFilter, get_application_call
0 ignored issues
show
Unused Code introduced by
Unused get_application_call imported from filter
Loading history...
Unused Code introduced by
Unused VariableFilter imported from filter
Loading history...
11
from ..roles import (INPUT, WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
0 ignored issues
show
Unused Code introduced by
Unused BATCH_NORM_MINIBATCH_ESTIMATE imported from roles
Loading history...
Unused Code introduced by
Unused INPUT imported from roles
Loading history...
12
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
13
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
14
                     add_role)
15
from ..utils import (shared_floatx_zeros, shared_floatx,
16
                     shared_floatx_nans)
17
from .base import lazy, application
18
from .sequences import Sequence, Feedforward, MLP
19
from .interfaces import RNGMixin
20
21
22
class BatchNormalization(RNGMixin, Feedforward):
    """Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph.  See
    :func:`batch_normalize`.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero shift), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        # Not Initializable, so supply sensible defaults here rather than
        # relying on configuration pushed from a parent brick.
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        """Normalize `input_` using the population statistics.

        Copies of the population statistics are used so that graph
        transformations (see :func:`batch_normalize`) can substitute
        minibatch estimates without touching the shared variables.
        """
        mean = self.population_mean.copy(name='population_offset')
        stdev = self.population_stdev.copy(name='population_divisor')

        def annotate(var, role):
            # Tag `var` with a role and both annotations so it can later
            # be located with a VariableFilter.
            add_role(var, role)
            add_annotation(var, self)
            add_annotation(var, application_call)

        annotate(mean, BATCH_NORM_OFFSET)
        annotate(stdev, BATCH_NORM_DIVISOR)

        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, self.W,
                                            self.b, mean, stdev,
                                            mode=('low_mem' if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def _allocate(self):
        # Normalize a scalar `input_dim` to a 1-tuple.
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = ((1,) +  # batch axis
                   tuple(1 if broadcast else dim for dim, broadcast in
                         equizip(input_dim, broadcastable)))
        # tuple() guards against a user-supplied list for `broadcastable`,
        # which would otherwise fail to concatenate with a tuple.
        broadcastable = (True,) + tuple(broadcastable)

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        """The learned scaling parameter ($\\gamma$ in [BN]_)."""
        return self._W

    @property
    def b(self):
        """The learned shift parameter ($\\beta$ in [BN]_)."""
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            # Include the offending name to make the error actionable.
            raise KeyError(name)

    @property
    def num_output_channels(self):
        return self.num_channels
194
195
196
class SpatialBatchNormalization(BatchNormalization):
    """Batch normalization averaged over the spatial axes of the input.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must be length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        # Reject anything that is not a sequence of at least two axes.
        is_sequence = isinstance(input_dim, collections.Sequence)
        if not is_sequence or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
        # Keep the leading (channels) axis, average over every other one.
        num_spatial_axes = len(input_dim) - 1
        kwargs.setdefault('broadcastable',
                          (False,) + (True,) * num_spatial_axes)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)
222
223
224
class BatchNormalizedMLP(MLP):
    """An MLP whose activations are preceded by batch normalization.

    Notes
    -----
    Takes exactly the same parameters as :class:`~blocks.bricks.MLP`.
    Every activation brick is replaced by a
    :class:`~blocks.bricks.Sequence` that applies a
    :class:`BatchNormalization` brick first and the original activation
    afterwards.

    Unless `use_bias` is explicitly passed as `True`, the contained
    :class:`~blocks.bricks.Linear` bricks are created without biases:
    the shift learned by the :class:`BatchNormalization` bricks would
    absorb them anyway.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        wrapped = []
        for index, activation in enumerate(activations):
            pair = Sequence([BatchNormalization().apply, activation.apply],
                            name='batch_norm_activation_{}'.format(index))
            wrapped.append(pair)
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(wrapped, dims, *args,
                                                 **kwargs)

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Push input dimensions into the BatchNormalization children
        # (first element of each Sequence): each one normalizes the
        # output of the preceding linear transformation, whose width is
        # the corresponding entry of `dims` past the input dimension.
        for pair, linear_output_dim in equizip(self.activations,
                                               self.dims[1:]):
            pair.children[0].input_dim = linear_output_dim
260
261