1
|
|
|
import collections |
2
|
|
|
from functools import partial |
3
|
|
|
|
4
|
|
|
import numpy |
5
|
|
|
from picklable_itertools.extras import equizip |
6
|
|
|
import theano |
7
|
|
|
from theano import tensor |
8
|
|
|
from theano.tensor.nnet import bn |
9
|
|
|
|
10
|
|
|
from ..graph import add_annotation |
11
|
|
|
from ..initialization import Constant |
12
|
|
|
from ..roles import (BATCH_NORM_POPULATION_MEAN, |
13
|
|
|
BATCH_NORM_POPULATION_STDEV, BATCH_NORM_OFFSET, |
14
|
|
|
BATCH_NORM_DIVISOR, BATCH_NORM_MINIBATCH_ESTIMATE, |
15
|
|
|
BATCH_NORM_SHIFT_PARAMETER, BATCH_NORM_SCALE_PARAMETER, |
16
|
|
|
add_role) |
17
|
|
|
from ..utils import shared_floatx, shared_floatx_nans, is_shared_variable |
18
|
|
|
from .base import lazy, application |
19
|
|
|
from .sequences import Sequence, Feedforward, MLP |
20
|
|
|
from .interfaces import RNGMixin |
21
|
|
|
|
22
|
|
|
|
23
|
|
|
def _add_batch_axis(var):
    """Prepend a singleton batch axis to a TensorVariable and name it.

    Parameters
    ----------
    var : TensorVariable
        The variable to which a leading broadcastable axis is added.

    Returns
    -------
    TensorVariable
        `var` with a new leading singleton axis, named
        ``shape_padleft(<original name>)``.

    """
    # Fixed: the original contained a duplicated assignment
    # (`new_var = new_var = ...`); a single assignment suffices.
    new_var = tensor.shape_padleft(var)
    new_var.name = 'shape_padleft({})'.format(var.name)
    return new_var
28
|
|
|
|
29
|
|
|
|
30
|
|
|
def _add_role_and_annotate(var, role, annotations=()):
    """Attach a role, then each of the given annotations, to a variable."""
    add_role(var, role)
    for note in annotations:
        add_annotation(var, note)
35
|
|
|
|
36
|
|
|
|
37
|
|
|
class BatchNormalization(RNGMixin, Feedforward):
    r"""Normalizes activations, parameterizes a scale and shift.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of a single input example. It is assumed that a batch axis
        will be prepended to this.
    broadcastable : tuple, optional
        Tuple of the same length as `input_dim` which specifies which of
        the per-example axes should be averaged over to compute means and
        standard deviations. For example, in order to normalize over all
        spatial locations in a `(batch_index, channels, height, width)`
        image, pass `(False, True, True)`. The batch axis is always
        averaged out.
    conserve_memory : bool, optional
        Use an implementation that stores less intermediate state and
        therefore uses less memory, at the expense of 5-10% speed. Default
        is `True`.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation (when the brick is run in training mode).
        Added to the variance inside the square root, as in the
        batch normalization paper.
    scale_init : object, optional
        Initialization object to use for the learned scaling parameter
        ($\\gamma$ in [BN]_). By default, uses constant initialization
        of 1.
    shift_init : object, optional
        Initialization object to use for the learned shift parameter
        ($\\beta$ in [BN]_). By default, uses constant initialization of 0.
    mean_only : bool, optional
        Perform "mean-only" batch normalization as described in [SK2016]_.
    learn_scale : bool, optional
        Whether to include a learned scale parameter ($\\gamma$ in [BN]_)
        in this brick. Default is `True`. Has no effect if `mean_only` is
        `True` (i.e. a scale parameter is never learned in mean-only mode).
    learn_shift : bool, optional
        Whether to include a learned shift parameter ($\\beta$ in [BN]_)
        in this brick. Default is `True`.

    Notes
    -----
    In order for trained models to behave sensibly immediately upon
    deserialization, by default, this brick runs in *inference* mode,
    using a population mean and population standard deviation (initialized
    to zeros and ones respectively) to normalize activations. It is
    expected that the user will adapt these during training in some
    fashion, independently of the training objective, e.g. by taking a
    moving average of minibatch-wise statistics.

    In order to *train* with batch normalization, one must obtain a
    training graph by transforming the original inference graph. See
    :func:`~blocks.graph.apply_batch_normalization` for a routine to
    transform graphs, and :func:`~blocks.graph.batch_normalization`
    for a context manager that may enable shorter compile times
    (every instance of :class:`BatchNormalization` is itself a context
    manager, entry into which causes applications to be in minibatch
    "training" mode, however it is usually more convenient to use
    :func:`~blocks.graph.batch_normalization` to enable this behaviour
    for all of your graph's :class:`BatchNormalization` bricks at once).

    Note that training in inference mode should be avoided, as this
    brick introduces scales and shift parameters (tagged with the
    `PARAMETER` role) that, in the absence of batch normalization,
    usually makes things unstable. If you must do this, filter for and
    remove `BATCH_NORM_SHIFT_PARAMETER` and `BATCH_NORM_SCALE_PARAMETER`
    from the list of parameters you are training, and this brick should
    behave as a (somewhat expensive) no-op.

    This Brick accepts `scale_init` and `shift_init` arguments but is
    *not* an instance of :class:`~blocks.bricks.Initializable`, and will
    therefore not receive pushed initialization config from any parent
    brick. In almost all cases, you will probably want to stick with the
    defaults (unit scale and zero offset), but you can explicitly pass one
    or both initializers to override this.

    This has the necessary properties to be inserted into a
    :class:`blocks.bricks.conv.ConvolutionalSequence` as-is, in which case
    the `input_dim` should be omitted at construction, to be inferred from
    the layer below.

    .. [BN] Sergey Ioffe and Christian Szegedy. *Batch normalization:
       accelerating deep network training by reducing internal covariate
       shift*. ICML (2015), pp. 448-456.

    .. [SK2016] Tim Salimans and Diederik P. Kingma. *Weight
       normalization: a simple reparameterization to accelerate training
       of deep neural networks*. arXiv 1602.07868.

    """
    @lazy(allocation=['input_dim'])
    def __init__(self, input_dim, broadcastable=None,
                 conserve_memory=True, epsilon=1e-4, scale_init=None,
                 shift_init=None, mean_only=False, learn_scale=True,
                 learn_shift=True, **kwargs):
        self.input_dim = input_dim
        self.broadcastable = broadcastable
        self.conserve_memory = conserve_memory
        self.epsilon = epsilon
        self.learn_scale = learn_scale
        self.learn_shift = learn_shift
        self.scale_init = (Constant(1) if scale_init is None
                           else scale_init)
        self.shift_init = (Constant(0) if shift_init is None
                           else shift_init)
        self.mean_only = mean_only
        # A stack rather than a boolean flag, so that nested uses of the
        # brick as a context manager behave correctly.
        self._training_mode = []
        super(BatchNormalization, self).__init__(**kwargs)

    @application(inputs=['input_'], outputs=['output'])
    def apply(self, input_, application_call):
        """Normalize `input_` using minibatch or population statistics."""
        if self._training_mode:
            mean, stdev = self._compute_training_statistics(input_)
        else:
            mean, stdev = self._prepare_population_statistics()
        # Useful for filtration of calls that were already made in
        # training mode when doing graph transformations.
        # Very important to cast to bool, as self._training_mode is
        # normally a list (to support nested context managers), which would
        # otherwise get passed by reference and be remotely mutated.
        application_call.metadata['training_mode'] = bool(self._training_mode)
        # Useful for retrieving a list of updates for population
        # statistics. Ditch the broadcastable first axis, though, to
        # make it the same dimensions as the population mean and stdev
        # shared variables.
        application_call.metadata['offset'] = mean[0]
        application_call.metadata['divisor'] = stdev[0]
        # Give these quantities roles in the graph.
        _add_role_and_annotate(mean, BATCH_NORM_OFFSET,
                               [self, application_call])
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # The annotation/role information is useless if it's a constant.
            _add_role_and_annotate(stdev, BATCH_NORM_DIVISOR,
                                   [self, application_call])
        shift = _add_batch_axis(self.shift)
        scale = _add_batch_axis(self.scale)
        # Heavy lifting is done by the Theano utility function.
        normalized = bn.batch_normalization(input_, scale, shift, mean, stdev,
                                            mode=('low_mem'
                                                  if self.conserve_memory
                                                  else 'high_mem'))
        return normalized

    def __enter__(self):
        # Entering the context pushes "training mode"; see __init__ for
        # why this is a stack.
        self._training_mode.append(True)

    def __exit__(self, *exc_info):
        self._training_mode.pop()

    def _compute_training_statistics(self, input_, axes=None):
        """Compute minibatch mean and standard deviation estimates."""
        if axes is None:
            axes = self.normalization_axes
        mean = input_.mean(axis=axes, keepdims=True)
        add_role(mean, BATCH_NORM_MINIBATCH_ESTIMATE)
        if self.mean_only:
            stdev = tensor.ones_like(mean)
        else:
            # We already have the mean; this saves Theano the trouble of
            # optimizing the graph produced by tensor.var().
            # This two pass version is going to be slightly more stable than
            # E[X^2] - E[X]^2, at least when n is small. It's also never
            # going to be negative due to numerical error.
            var = tensor.mean(tensor.sqr(input_ - mean),
                              axis=axes, keepdims=True)
            eps = numpy.cast[theano.config.floatX](self.epsilon)
            stdev = tensor.sqrt(var + eps)
            add_role(stdev, BATCH_NORM_MINIBATCH_ESTIMATE)
        return mean, stdev

    @property
    def normalization_axes(self):
        """Axes to average over: the batch axis plus broadcastable axes."""
        return (0,) + tuple((i + 1) for i, b in
                            enumerate(self.population_mean.broadcastable)
                            if b)

    def _prepare_population_statistics(self):
        """Return population mean/stdev with a prepended batch axis."""
        mean = _add_batch_axis(self.population_mean)
        if self.mean_only:
            stdev = tensor.ones_like(self.population_mean)
        else:
            stdev = self.population_stdev
        stdev = _add_batch_axis(stdev)
        return mean, stdev

    def _allocate(self):
        """Allocate shift/scale parameters and population statistics."""
        input_dim = ((self.input_dim,)
                     if not isinstance(self.input_dim, collections.Sequence)
                     else self.input_dim)
        broadcastable = (tuple(False for _ in input_dim)
                         if self.broadcastable is None else self.broadcastable)
        if len(input_dim) != len(broadcastable):
            raise ValueError("input_dim and broadcastable must be same length")
        # Broadcastable (averaged-over) axes collapse to a size of 1.
        var_dim = tuple(1 if broadcast else dim for dim, broadcast in
                        equizip(input_dim, broadcastable))
        # (Fixed: removed a no-op self-assignment of `broadcastable`.)

        # "beta", from the Ioffe & Szegedy manuscript.
        if self.learn_shift:
            self.shift = shared_floatx_nans(var_dim, name='batch_norm_shift',
                                            broadcastable=broadcastable)
            add_role(self.shift, BATCH_NORM_SHIFT_PARAMETER)
            self.parameters.append(self.shift)
        else:
            self.shift = tensor.constant(0, dtype=theano.config.floatX)

        if self.learn_scale and not self.mean_only:
            # "gamma", from the Ioffe & Szegedy manuscript.
            self.scale = shared_floatx_nans(var_dim, name='batch_norm_scale',
                                            broadcastable=broadcastable)

            add_role(self.scale, BATCH_NORM_SCALE_PARAMETER)
            self.parameters.append(self.scale)
        else:
            self.scale = tensor.constant(1., dtype=theano.config.floatX)

        self._allocate_population_statistics(var_dim, broadcastable)

    def _allocate_population_statistics(self, var_dim, broadcastable):
        """Create the population mean/stdev shared-variable buffers."""
        def _allocate_buffer(name, role, value):
            # These aren't technically parameters, in that they should not be
            # learned using the same cost function as other model parameters.
            population_buffer = shared_floatx(value * numpy.ones(var_dim),
                                              broadcastable=broadcastable,
                                              name=name)
            add_role(population_buffer, role)
            # Normally these would get annotated by an AnnotatingList, but they
            # aren't in self.parameters.
            add_annotation(population_buffer, self)
            return population_buffer

        self.population_mean = _allocate_buffer('population_mean',
                                                BATCH_NORM_POPULATION_MEAN,
                                                numpy.zeros(var_dim))

        self.population_stdev = _allocate_buffer('population_stdev',
                                                 BATCH_NORM_POPULATION_STDEV,
                                                 numpy.ones(var_dim))

    def _initialize(self):
        # We gate with is_shared_variable rather than relying on
        # learn_scale and learn_shift so as to avoid the unlikely but nasty
        # scenario where those flags are changed post-allocation but
        # pre-initialization. This ensures that such a change simply has no
        # effect rather than doing an inconsistent combination of things.
        if is_shared_variable(self.shift):
            self.shift_init.initialize(self.shift, self.rng)
        if is_shared_variable(self.scale):
            self.scale_init.initialize(self.scale, self.rng)

    # Needed for the Feedforward interface.
    @property
    def output_dim(self):
        return self.input_dim

    # The following properties allow for BatchNormalization bricks
    # to be used directly inside of a ConvolutionalSequence.
    @property
    def image_size(self):
        return self.input_dim[-2:]

    @image_size.setter
    def image_size(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (None,) + tuple(value)
        else:
            self.input_dim = (self.input_dim[0],) + tuple(value)

    @property
    def num_channels(self):
        return self.input_dim[0]

    @num_channels.setter
    def num_channels(self, value):
        if not isinstance(self.input_dim, collections.Sequence):
            self.input_dim = (value,) + (None, None)
        else:
            self.input_dim = (value,) + self.input_dim[-2:]

    def get_dim(self, name):
        """Return the dimension of `name` ('input' or 'output')."""
        if name in ('input', 'output'):
            return self.input_dim
        else:
            raise KeyError

    @property
    def num_output_channels(self):
        return self.num_channels
328
|
|
|
|
329
|
|
|
|
330
|
|
|
class SpatialBatchNormalization(BatchNormalization):
    """Batch normalization specialized for spatial (image-like) inputs.

    Parameters
    ----------
    input_dim : int or tuple
        Shape of one example; must have length at least 2. The leading
        axis is taken to be a "channels" axis, which is not averaged
        over; every remaining axis is treated as a spatial dimension.

    Notes
    -----
    See :class:`BatchNormalization` for further details and the
    remaining keyword arguments.

    """
    def _allocate(self):
        # Validate that the input shape looks like (channels, spatial...).
        shape_is_sequence = isinstance(self.input_dim, collections.Sequence)
        if not shape_is_sequence or len(self.input_dim) < 2:
            raise ValueError('expected input_dim to be length >= 2 '
                             'e.g. (channels, height, width)')
        # Average over every axis except the channels axis.
        num_spatial_axes = len(self.input_dim) - 1
        self.broadcastable = (False,) + (True,) * num_spatial_axes
        super(SpatialBatchNormalization, self)._allocate()
354
|
|
|
|
355
|
|
|
|
356
|
|
|
class BatchNormalizedMLP(MLP):
    """Convenient subclass for building an MLP with batch normalization.

    Parameters
    ----------
    conserve_memory : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    mean_only : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_scale : bool, optional, by keyword only
        See :class:`BatchNormalization`.
    learn_shift : bool, optional, by keyword only
        See :class:`BatchNormalization`.

    Notes
    -----
    All other parameters are the same as :class:`~blocks.bricks.MLP`. Each
    activation brick is wrapped in a :class:`~blocks.bricks.Sequence`
    containing an appropriate :class:`BatchNormalization` brick and
    the activation that follows it.

    By default, the contained :class:`~blocks.bricks.Linear` bricks will
    not contain any biases, as they could be canceled out by the biases
    in the :class:`BatchNormalization` bricks being added. Pass
    `use_bias` with a value of `True` if you really want this for some
    reason.

    `mean_only`, `learn_scale` and `learn_shift` are pushed down to
    all created :class:`BatchNormalization` bricks as allocation
    config.

    """
    @lazy(allocation=['dims'])
    def __init__(self, activations, dims, *args, **kwargs):
        # Pop our keyword-only options before delegating to MLP.__init__,
        # which does not accept them.
        self._conserve_memory = kwargs.pop('conserve_memory', True)
        self.mean_only = kwargs.pop('mean_only', False)
        self.learn_scale = kwargs.pop('learn_scale', True)
        self.learn_shift = kwargs.pop('learn_shift', True)

        # Wrap each activation as (BatchNormalization -> activation).
        # The BatchNormalization bricks' input_dim is left unset here;
        # it is filled in during _push_allocation_config.
        activations = [
            Sequence([
                (BatchNormalization(conserve_memory=self._conserve_memory)
                 .apply),
                act.apply
            ], name='batch_norm_activation_{}'.format(i))
            for i, act in enumerate(activations)
        ]
        # Batch normalization bricks incorporate a bias, so there's no
        # need for our Linear bricks to have them.
        kwargs.setdefault('use_bias', False)
        super(BatchNormalizedMLP, self).__init__(activations, dims, *args,
                                                 **kwargs)

    def _nested_brick_property_getter(self, property_name):
        # Read the privately-stored value (e.g. self._conserve_memory).
        return getattr(self, '_' + property_name)

    def _nested_brick_property_setter(self, value, property_name):
        # Store the value locally, then propagate it to each child
        # BatchNormalization brick (the first child of each Sequence).
        setattr(self, '_' + property_name, value)
        for act in self.activations:
            assert isinstance(act.children[0], BatchNormalization)
            setattr(act.children[0], property_name, value)

    # conserve_memory is a bit special in that it can be modified
    # after construction/allocation and still have a valid effect on
    # apply(). Thus we propagate down all property sets.
    conserve_memory = property(partial(_nested_brick_property_getter,
                                       property_name='conserve_memory'),
                               partial(_nested_brick_property_setter,
                                       property_name='conserve_memory'),
                               doc="Conserve memory.")

    def _push_allocation_config(self):
        super(BatchNormalizedMLP, self)._push_allocation_config()
        # Do the extra allocation pushing for the BatchNormalization
        # bricks. They need as their input dimension the output dimension
        # of each linear transformation. Exclude the first dimension,
        # which is the input dimension.
        for act, dim in equizip(self.activations, self.dims[1:]):
            assert isinstance(act.children[0], BatchNormalization)
            act.children[0].input_dim = dim
            for attr in ['mean_only', 'learn_scale', 'learn_shift']:
                setattr(act.children[0], attr, getattr(self, attr))
438
|
|
|
|