|
1
|
|
|
"""Implements the batch normalization training graph transform. |
|
2
|
|
|
|
|
3
|
|
|
Specifically, this module contains the implementation for the |
|
4
|
|
|
transformation of a batch-normalized inference graph into training graph, |
|
5
|
|
|
which uses minibatch statistics in place of population statistics. |
|
6
|
|
|
|
|
7
|
|
|
""" |
|
8
|
|
|
import collections |
|
9
|
|
|
import contextlib |
|
10
|
|
|
from functools import partial |
|
11
|
|
|
|
|
12
|
|
|
import theano |
|
13
|
|
|
from toolz import isdistinct |
|
14
|
|
|
|
|
15
|
|
|
from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT |
|
16
|
|
|
from ..utils import find_bricks |
|
17
|
|
|
|
|
18
|
|
|
|
|
19
|
|
|
def _training_mode_application_calls(application_calls):
    """Return only the application calls made in 'training mode'."""
    from ..bricks import BatchNormalization
    training_calls = []
    for call in application_calls:
        # Sanity-check that every call really comes from a
        # BatchNormalization.apply application.
        brick = call.application.brick
        assert isinstance(brick, BatchNormalization)
        assert call.application.application == BatchNormalization.apply
        if not call.metadata.get('training_mode', False):
            continue
        training_calls.append(call)
    return training_calls
|
|
30
|
|
|
|
|
31
|
|
|
@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager that switches batch normalization to "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks. Each is searched recursively for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    An alternative to :func:`apply_batch_normalization`: the graph
    replacement done there, while elegant, can produce very large Theano
    graphs with correspondingly slow compile times. Building the graph
    inside this context manager avoids that, at the cost of constructing
    the graph a second time if you also want to monitor the inference
    graph's output during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn_bricks = find_bricks(bricks,
                            lambda brick: isinstance(brick,
                                                     BatchNormalization))
    # contextlib.nested() is deprecated and ExitStack is unavailable on
    # Python 2.7, so enter and exit each brick's context by hand,
    # unwinding in reverse order on the way out.
    try:
        for brick in bn_bricks:
            brick.__enter__()
        yield
    finally:
        for brick in reversed(bn_bricks):
            brick.__exit__()
101
|
|
|
|
|
102
|
|
|
|
|
103
|
|
|
def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    # Every application call should contribute exactly one variable of
    # each role, so all four dicts must be the same length.
    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were generated by an apply() while
    # already in training mode.
    # NOTE: materialize the keys up front. On Python 3, dict.keys() is a
    # live view, and deleting entries below while iterating over it would
    # raise "RuntimeError: dictionary changed size during iteration".
    app_calls = list(inputs)
    remove = set(_training_mode_application_calls(app_calls))
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        # Reapply the brick inside its own context manager so the new
        # application runs in training mode (minibatch statistics).
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        new_app_call = get_application_call(new_output)
        assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)
|
209
|
|
|
|
|
210
|
|
|
def get_batch_normalization_updates(training_graph, allow_duplicates=False):
    """Extract correspondences for learning BN population statistics.

    Parameters
    ----------
    training_graph : :class:`~blocks.graph.ComputationGraph`
        A graph of expressions wherein "training mode" batch normalization
        is taking place.
    allow_duplicates : bool, optional
        If `True`, allow multiple training-mode application calls from the
        same :class:`~blocks.bricks.BatchNormalization` instance, and
        return pairs corresponding to all of them. It's then the user's
        responsibility to do something sensible to resolve the duplicates.

    Returns
    -------
    update_pairs : list of tuples
        2-tuples pairing a shared variable that holds a "population" mean
        or standard deviation with the Theano variable for the matching
        minibatch statistic. A single
        :class:`blocks.bricks.BatchNormalization` brick may be applied
        several times in the graph (recurrent models, siamese networks,
        other pathway-reusing models), so with `allow_duplicates` set a
        given population variable can appear in several pairs, each with
        a different minibatch variable.

    Notes
    -----
    Used in their raw form, these updates will simply overwrite the
    population statistics with the minibatch statistics at every gradient
    step. You will probably want to transform these pairs into something
    more sensible, such as keeping a moving average of minibatch values,
    or accumulating an average over the entire training set once every few
    epochs.

    """
    from ..bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call
    var_filter = VariableFilter(bricks=[BatchNormalization], roles=[OUTPUT])
    all_calls = map(get_application_call, var_filter(training_graph))
    train_calls = _training_mode_application_calls(all_calls)
    if not train_calls:
        raise ValueError("no training mode BatchNormalization "
                         "applications found in graph")

    if not allow_duplicates:
        # Reject a brick that was applied more than once in training mode
        # unless the caller explicitly opted in to duplicates.
        if not isdistinct([call.application.brick for call in train_calls]):
            raise ValueError('multiple applications of the same '
                             'BatchNormalization brick; pass allow_duplicates '
                             '= True to override this check')

    def mean_pair(call):
        # (shared population mean, minibatch mean used as the offset).
        return (call.application.brick.population_mean,
                call.metadata['offset'])

    def stdev_pair(call):
        # (shared population stdev, minibatch stdev used as the divisor).
        return (call.application.brick.population_stdev,
                call.metadata['divisor'])

    update_pairs = []
    for call in train_calls:
        update_pairs.append(mean_pair(call))
        # Mean-only bricks track no population stdev.
        if not call.application.brick.mean_only:
            update_pairs.append(stdev_pair(call))
    return update_pairs
This check looks for invalid names across a range of identifier kinds.
If the defaults do not match your requirements, you can configure regular expressions to which the identifiers must conform.
If your project includes a Pylint configuration file, the settings in that file take precedence.
To find out more about Pylint, please refer to the Pylint website.