import collections

import numpy
from picklable_itertools.extras import equizip
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..filter import VariableFilter, get_application_call
from ..roles import (INPUT, WEIGHT, BIAS, BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     add_role)
from ..utils import (shared_floatx_zeros, shared_floatx,
                     shared_floatx_nans)
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


class BatchNormalization(RNGMixin, Feedforward):
    """Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    save_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    weights_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    biases_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`batch_normalize`.

    This Brick accepts `weights_init` and `biases_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero shift), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

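    Examples
    --------
    A minimal usage sketch; by default the output is normalized with the
    population statistics, which start out at zero mean and unit standard
    deviation (see the notes above):

    >>> x = tensor.matrix('x')
    >>> brick = BatchNormalization(input_dim=100)
    >>> brick.initialize()
    >>> y = brick.apply(x)  # inference-mode output
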
    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 save_memory=True, weights_init=None,
                 biases_init=None, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.save_memory = save_memory
        self.weights_init = (Constant(1) if weights_init is None
                             else weights_init)
        self.biases_init = (Constant(0) if biases_init is None
                            else biases_init)
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
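        # Work with symbolic copies of the population statistics: the
        # copies carry the OFFSET/DIVISOR roles below, which is what lets
        # batch_normalize() find and replace them with minibatch estimates.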
        mean = self.population_mean.copy(name='population_offset')
        stdev = self.population_stdev.copy(name='population_divisor')

        def annotate(var, role):
            add_role(var, role)
            add_annotation(var, self)
            add_annotation(var, application_call)

        annotate(mean, BATCH_NORM_OFFSET)
        annotate(stdev, BATCH_NORM_DIVISOR)

        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, self.W,
                                            self.b, mean, stdev,
                                            mode=('low_mem' if self.save_memory
                                                  else 'high_mem'))
        return normalized

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in range(len(input_dim)))
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = ((1,) +  # batch axis
                   tuple(1 if broadcast else dim for dim, broadcast in
                         equizip(input_dim, broadcastable)))
        broadcastable = (True,) + broadcastable

        # "gamma", from the Ioffe & Szegedy manuscript.
        self._W = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                     broadcastable=broadcastable)

        # "beta", from the Ioffe & Szegedy manuscript.
        self._b = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                     broadcastable=broadcastable)
        add_role(self.W, WEIGHT)
        add_role(self.b, BIAS)
        self.parameters.append(self.W)
        self.parameters.append(self.b)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                              name='population_stdev',
                                              broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)
        add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)

    @property
    def W(self):
        return self._W

    @property
    def b(self):
        return self._b

    def _initialize(self):
        self.biases_init.initialize(self.b, self.rng)
        self.weights_init.initialize(self.W, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must be length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

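    Examples
    --------
    A brief sketch for a `(channels, height, width)` input: the channel
    axis keeps its own statistics while both spatial axes are averaged
    over:

    >>> x = tensor.tensor4('x')
    >>> brick = SpatialBatchNormalization(input_dim=(3, 32, 32))
    >>> brick.initialize()
    >>> y = brick.apply(x)
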
    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, **kwargs):
        if not isinstance(input_dim,
                          collections.Sequence) or len(input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             '(channels, height, width)')
        broadcastable = (False,) + ((True,) * (len(input_dim) - 1))
        kwargs.setdefault('broadcastable', broadcastable)
        super(SpatialBatchNormalization, self).__init__(input_dim, **kwargs)


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Notes
    -----
    All parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

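    Examples
    --------
    A brief sketch; `Tanh` and `IsotropicGaussian` are standard Blocks
    components assumed to be available, not defined in this module:

    >>> from blocks.bricks import Tanh
    >>> from blocks.initialization import IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [100, 50, 10],
    ...                          weights_init=IsotropicGaussian(0.01))
    >>> mlp.initialize()
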
    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        activations = [Sequence([BatchNormalization().apply, act.apply],
                                name='batch_norm_activation_{}'.format(i))
                       for i, act in enumerate(activations)]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            act.children[0].input_dim = dim


def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.

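    Examples
    --------
    A sketch of the intended workflow; `cost` stands for whatever Theano
    scalar you optimize, and :class:`~blocks.graph.ComputationGraph` is
    assumed to be importable from :mod:`blocks.graph`::

        from blocks.graph import ComputationGraph
        inference_cg = ComputationGraph([cost])
        training_cg = batch_normalize(inference_cg)
        training_cost, = training_cg.outputs
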
    """

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    mean_filter, stdev_filter, input_filter = map(make_variable_filter,
                                                  [BATCH_NORM_OFFSET,
                                                   BATCH_NORM_DIVISOR, INPUT])

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    def get_application_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    means, stdevs, inputs = map(get_application_call_dict,
                                [mean_filter, stdev_filter, input_filter])

    assert (set(means.keys()) == set(stdevs.keys()) and
            set(means.keys()) == set(inputs.keys()))
    assert set(means.values()).isdisjoint(stdevs.values())

    replacements = []
    # Perform replacement for each application call.
    for application_call in means:
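        # Aggregate over the batch axis plus any per-example axes the
        # brick was configured to normalize over; these are exactly the
        # broadcastable axes of the population mean variable.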
        axes = tuple(i for i, b in enumerate(means[application_call]
                                             .broadcastable) if b)
        minibatch_mean = inputs[application_call].mean(axis=axes,
                                                       keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(tensor.var(inputs[application_call],
                                               axis=axes, keepdims=True)
                                    + epsilon)
        minibatch_std.name = 'minibatch_divisor'

        def prepare_replacement(old, new, role, application_call):
            """Add roles and tags to replaced variables."""
            add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
            add_role(new, role)
            add_annotation(new, application_call)
            add_annotation(new, application_call.application.brick)
            new.tag.replacement_of = old
            replacements.append((old, new))

        prepare_replacement(means[application_call], minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    return computation_graph.replace(replacements)


def population_to_minibatch(bn_graph):
    """Get a mapping from population statistics to minibatch estimates.

    Parameters
    ----------
    bn_graph : :class:`~blocks.graph.ComputationGraph`
        Graph returned by :func:`batch_normalize`.

    Returns
    -------
    OrderedDict
        A mapping from variables representing population statistics
        to the corresponding minibatch estimates that replace them in
        the batch-normalized graph.

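    Examples
    --------
    A sketch of how the mapping might be inspected (`bn_graph` is a graph
    previously returned by :func:`batch_normalize`)::

        mapping = population_to_minibatch(bn_graph)
        for population_var, minibatch_var in mapping.items():
            print(population_var.name, '->', minibatch_var.name)

    One possible use for this mapping is constructing updates that keep a
    moving average of the minibatch estimates, as suggested in the
    :class:`BatchNormalization` notes.
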
    """
    variables = VariableFilter(roles=[BATCH_NORM_MINIBATCH_ESTIMATE])(bn_graph)
    return collections.OrderedDict((v.tag.replacement_of, v)
                                   for v in variables)