1
|
|
|
import collections |
2
|
|
|
from functools import partial |
3
|
|
|
|
4
|
|
|
import numpy |
5
|
|
|
from picklable_itertools.extras import equizip |
6
|
|
|
import theano |
7
|
|
|
from theano import tensor |
8
|
|
|
from theano.tensor.nnet import bn |
9
|
|
|
|
10
|
|
|
from ..graph import add_annotation |
11
|
|
|
from ..initialization import Constant |
12
|
|
|
from ..roles import (BATCH_NORM_POPULATION_MEAN, |
13
|
|
|
BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET, |
14
|
|
|
BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE, |
15
|
|
|
BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER, |
16
|
|
|
add_role) |
17
|
|
|
from ..utils import (shared_floatx_zeros, shared_floatx, |
18
|
|
|
shared_floatx_nans) |
19
|
|
|
from .base import lazy, application |
20
|
|
|
from .sequences import Sequence, Feedforward, MLP |
21
|
|
|
from .interfaces import RNGMixin |
22
|
|
|
|
23
|
|
|
|
24
|
|
|
def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        Variable representing a single (batch-less) example; assumed to
        have a `name` attribute set.

    Returns
    -------
    :class:`~tensor.TensorVariable`
        A view of `var` with one broadcastable axis prepended, named
        ``'shape_padleft(<var.name>)'``.

    """
    # Fixed: the original had a redundant chained assignment
    # (`new_var = new_var = ...`); a single binding is sufficient.
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var
29
|
|
|
|
30
|
|
|
|
31
|
|
|
def _add_role_and_annotate(var, role, annotations=()):
    """Tag a variable with a role and attach any number of annotations.

    Parameters
    ----------
    var : :class:`~tensor.TensorVariable`
        The variable to tag.
    role : :class:`~blocks.roles.VariableRole`
        The role to add to `var`.
    annotations : iterable, optional
        Zero or more annotation objects to attach to `var`.

    """
    add_role(var, role)
    for note in annotations:
        add_annotation(var, note)
36
|
|
|
|
37
|
|
|
|
38
|
|
|
class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple the same length as `input_dim` which specifies which of the
        per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode, however it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        # A stack of flags rather than a plain boolean, so that nested
        # uses of this brick as a context manager unwind correctly.
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        # In training mode, normalize by minibatch statistics; otherwise
        # use the stored population statistics.
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            # Mean-only batch normalization: unit scale, unit divisor.
            scale = tensor.ones_like(self.shift)
            stdev = tensor.ones_like(mean)
        else:
            scale = self.scale
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        # Entering the context pushes training mode onto the stack.
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_):
        """Compute minibatch mean and standard deviation for `input_`."""
        # Always average over the batch axis (0); additionally average
        # over any per-example axis marked broadcastable.
        axes = (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)
        mean = input_.mean(axis=axes, keepdims=True)
        assert mean.broadcastable[1:] == self.population_mean.broadcastable
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # Epsilon is added to the variance inside the square root,
            # as in the batch normalization paper.
            stdev = tensor.sqrt(tensor.var(input_, axis=axes, keepdims=True) +
                                numpy.cast[theano.config.floatX](self.epsilon))
            assert (stdev.broadcastable[1:] ==
                    self.population_stdev.broadcastable)
        add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    def _prepare_population_statistics(self):
        """Return population statistics, padded with a batch axis."""
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        # Accept a bare int input_dim by promoting it to a 1-tuple.
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        # Axes that get averaged over collapse to size 1 in the
        # parameter/statistic shapes.
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))
        # (Removed a dead no-op self-assignment of `broadcastable` here.)

        # "beta", from the Ioffe & Szegedy manuscript.
        self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                        broadcastable=broadcastable)
        add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
        self.parameters.append(self.shift)

        # These aren't technically parameters, in that they should not be
        # learned using the same cost function as other model parameters.
        self.population_mean = shared_floatx_zeros(var_dim,
                                                   name='population_mean',
                                                   broadcastable=broadcastable)
        add_role(self.population_mean, BATCH_NORM_POPULATION_MEAN)

        # Normally these would get annotated by an AnnotatingList, but they
        # aren't in self.parameters.
        add_annotation(self.population_mean, self)

        if not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)

            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)

            self.population_stdev = shared_floatx(numpy.ones(var_dim),
                                                  name='population_stdev',
                                                  broadcastable=broadcastable)
            add_role(self.population_stdev, BATCH_NORM_POPULATION_STDEV)
            add_annotation(self.population_stdev, self)

    def _initialize(self):
        self.shift_init.initialize(self.shift, self.rng)
        if not self.mean_only:
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels
294
|
|
|
|
295
|
|
|
|
296
|
|
|
class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : int or tuple
        The input size of a single example. Must be length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be summed over, and all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).

    """
    def _allocate(self):
        # Validate before allocating: a scalar or a 1-tuple cannot
        # describe a (channels, spatial...) layout.
        dims = self.input_dim
        if not (isinstance(dims, collections.Sequence) and len(dims) >= 2):
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        # Average over every axis except the leading channels axis.
        self.broadcastable = (False,) + (True,) * (len(dims) - 1)
        super(SpatialBatchNormalization, self)._allocate()
320
|
|
|
|
321
|
|
|
|
322
|
|
|
class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional
        See :class:`BatchNormalization`.
    mean_only : bool, optional
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        # Stored privately; exposed via the `conserve_memory` and
        # `mean_only` properties below, which also forward assignments
        # to the child BatchNormalization bricks.
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self._mean_only = kwargs.pop('mean_only', False)
        # Replace each activation with a Sequence of
        # (BatchNormalization, activation), applied in that order.
        activations = [
            Sequence([
                BatchNormalization(conserve_memory=self._conserve_memory,
                                   mean_only=self._mean_only).apply,
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        # Read the privately stored value (e.g. `_conserve_memory`).
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        # Store locally, then mirror the new value onto every child
        # BatchNormalization brick (the first child of each Sequence).
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # Properties built from the plain (unbound) functions above via
    # functools.partial, so getting/setting propagates to children.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'))

    mean_only = property(partial(_nested_brick_property_getter,
                                 property_name='mean_only'),
                         partial(_nested_brick_property_setter,
                                 property_name='mean_only'))

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
392
|
|
|
|