import collections
from functools import partial

import numpy
from picklable_itertools.extras import equizip
import theano
from theano import tensor
from theano.tensor.nnet import bn

from ..graph import add_annotation
from ..initialization import Constant
from ..roles import (BATCH_NORM_POPULATION_MEAN,
                     BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET,
                     BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE,
                     BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER,
                     add_role)
from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable
from .base import lazy, application
from .sequences import Sequence, Feedforward, MLP
from .interfaces import RNGMixin


def _add_batch_axis(var):
    """Prepend a singleton axis to a TensorVariable and name it."""
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var


def _add_role_and_annotate(var, role, annotations=()):
    """Add a role and zero or more annotations to a variable."""
    add_role(var, role)
    for annotation in annotations:
        add_annotation(var, annotation)


class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.
    learn_scale : bool, optional
        Whether to include a learned scale parameter ($\\gamma$ in [BN]_)
        in this brick. Default is `True`. Has no effect if `mean_only` is
        `True` (i.e. a scale parameter is never learned in mean-only mode).
    learn_shift : bool, optional
        Whether to include a learned shift parameter ($\\beta$ in [BN]_)
        in this brick. Default is `True`.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode; however, it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scale and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually make things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.


    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.
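
    Examples
    --------
    A minimal sketch of inference- versus minibatch-mode application; the
    dimensionality and variable names here are illustrative, not
    prescribed by this brick::

        from theano import tensor

        x = tensor.matrix('features')
        brick = BatchNormalization(input_dim=100)
        brick.initialize()
        y_inference = brick.apply(x)      # uses population statistics
        with brick:
            y_training = brick.apply(x)   # uses minibatch statistics

    To obtain the same behaviour for every such brick in a graph at once,
    see :func:`~blocks.graph.batch_normalization`, as discussed in the
    Notes above.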

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, learn_scale=True,
                 learn_shift=True, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.learn_scale = learn_scale
        self.learn_shift = learn_shift
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(self.scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_, axes=None):
        if axes is None:
            axes = self.normalization_axes
        mean = input_.mean(axis=axes, keepdims=True)
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # We already have the mean; this saves Theano the trouble of
            # optimizing the graph produced by tensor.var().
            # This two pass version is going to be slightly more stable than
            # E[X^2] - E[X]^2, at least when n is small. It's also never
            # going to be negative due to numerical error.
            var = tensor.mean(tensor.sqr(input_ - mean),
                              axis=axes, keepdims=True)
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    @property
    def normalization_axes(self):
        return (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)

    def _prepare_population_statistics(self):
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))

        # "beta", from the Ioffe & Szegedy manuscript.
        if self.learn_shift:
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                            broadcastable=broadcastable)
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
            self.parameters.append(self.shift)
        else:
            self.shift = tensor.constant(0, dtype=theano.config.floatX)

        if self.learn_scale and not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)
            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)
        else:
            self.scale = tensor.constant(1., dtype=theano.config.floatX)

        self._allocate_population_statistics(var_dim, broadcastable)

    def _allocate_population_statistics(self, var_dim, broadcastable):
        def _allocate_buffer(name, role, value):
            # These aren't technically parameters, in that they should not be
            # learned using the same cost function as other model parameters.
            population_buffer = shared_floatx(value * numpy.ones(var_dim),
                                              broadcastable=broadcastable,
                                              name=name)
            add_role(population_buffer, role)
            # Normally these would get annotated by an AnnotatingList, but
            # they aren't in self.parameters.
            add_annotation(population_buffer, self)
            return population_buffer

        self.population_mean = _allocate_buffer('population_mean',
                                                BATCH_NORM_POPULATION_MEAN,
                                                numpy.zeros(var_dim))

        self.population_stdev = _allocate_buffer('population_stdev',
                                                 BATCH_NORM_POPULATION_STDEV,
                                                 numpy.ones(var_dim))

    def _initialize(self):
        # We gate with is_shared_variable rather than relying on
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
        # scenario where those flags are changed post-allocation but
        # pre-initialization. This ensures that such a change simply has no
        # effect rather than doing an inconsistent combination of things.
        if is_shared_variable(self.shift):
            self.shift_init.initialize(self.shift, self.rng)
        if is_shared_variable(self.scale):
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels


class SpatialBatchNormalization(BatchNormalization):
    """Convenient subclass for batch normalization across spatial inputs.

    Parameters
    ----------
    input_dim : tuple
        The input size of a single example. Must have length at least 2.
        It's assumed that the first axis of this tuple is a "channels"
        axis, which should not be averaged over, and that all remaining
        dimensions are spatial dimensions.

    Notes
    -----
    See :class:`BatchNormalization` for more details (and additional
    keyword arguments).
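
    Examples
    --------
    A minimal construction sketch for feature maps with an assumed
    per-example shape of ``(channels, height, width)``; the sizes are
    illustrative::

        brick = SpatialBatchNormalization(input_dim=(16, 32, 32))
        brick.initialize()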

    """
    def _allocate(self):
        if not isinstance(self.input_dim,
                          collections.Sequence) or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        self.broadcastable = (False,) + ((True,) * (len(self.input_dim) - 1))
        super(SpatialBatchNormalization, self)._allocate()


class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    mean_only : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_scale : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_shift : bool, optional, by keyword only
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick followed
    by the activation itself.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they would be redundant with the shift
    parameters of the :class:`BatchNormalization` bricks being added.
    Pass `use_bias` with a value of `True` if you really want them for
    some reason.

    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
    all created :class:`BatchNormalization` bricks as allocation
    config.
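
    Examples
    --------
    A minimal construction sketch; ``Tanh`` is assumed to be an
    activation brick importable from :mod:`blocks.bricks`, and the layer
    sizes are purely illustrative::

        from blocks.bricks import Tanh

        mlp = BatchNormalizedMLP([Tanh(), Tanh()], [100, 50, 10],
                                 mean_only=True)

    Initialization and training then proceed as for an ordinary
    :class:`~blocks.bricks.MLP`.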

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self.mean_only = kwargs.pop('mean_only', False)
        self.learn_scale = kwargs.pop('learn_scale', True)
        self.learn_shift = kwargs.pop('learn_shift', True)

        activations = [
            Sequence([
                (BatchNormalization(conserve_memory=self._conserve_memory)
                 .apply),
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # conserve_memory is a bit special in that it can be modified
    # after construction/allocation and still have a valid effect on
    # apply(). Thus we propagate down all property sets.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'),
                               doc="Conserve memory.")

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
                setattr(act.children[0], attr, getattr(self, attr))