1
|
|
|
"""Implements the batch normalization training graph transform. |
2
|
|
|
|
3
|
|
|
Specifically, this module contains the implementation for the |
4
|
|
|
transformation of a batch-normalized inference graph into training graph, |
5
|
|
|
which uses minibatch statistics in place of population statistics. |
6
|
|
|
|
7
|
|
|
""" |
8
|
|
|
import collections |
9
|
|
|
import contextlib |
10
|
|
|
from functools import partial |
11
|
|
|
|
12
|
|
|
import theano |
13
|
|
|
from toolz import isdistinct |
14
|
|
|
|
15
|
|
|
from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT |
16
|
|
|
from ..utils import find_bricks |
17
|
|
|
|
18
|
|
|
|
19
|
|
|
def _training_mode_application_calls(application_calls):
    """Filter for application calls made in 'training mode'."""
    from ..bricks import BatchNormalization
    training_calls = []
    for call in application_calls:
        # Sanity-check that every call really is a BatchNormalization.apply.
        assert isinstance(call.application.brick, BatchNormalization)
        assert call.application.application == BatchNormalization.apply
        if not call.metadata.get('training_mode', False):
            continue
        training_calls.append(call)
    return training_calls
29
|
|
|
|
30
|
|
|
|
31
|
|
|
@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn_bricks = find_bricks(bricks,
                            lambda b: isinstance(b, BatchNormalization))
    # Neither contextlib.nested() (deprecated) nor ExitStack (absent on
    # Python 2.7) is usable here, so enter and exit the bricks manually,
    # unwinding in reverse order in the finally clause.
    try:
        for bn_brick in bn_bricks:
            bn_brick.__enter__()
        yield
    finally:
        for bn_brick in reversed(bn_bricks):
            bn_brick.__exit__()
|
|
|
|
101
|
|
|
|
102
|
|
|
|
103
|
|
|
def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    # Every application call should contribute exactly one variable of
    # each role.
    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # Materialize the keys into a list first: iterating a Python 3 keys
    # view while deleting from the dict raises RuntimeError.
    app_calls = list(inputs)
    remove = set(_training_mode_application_calls(app_calls))
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        # Re-apply the brick on the uncopied input, inside the brick's
        # context manager so the application runs in training mode.
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        new_app_call = get_application_call(new_output)
        assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)
208
|
|
|
|
209
|
|
|
|
210
|
|
|
def get_batch_normalization_updates(training_graph, allow_duplicates=False):
    """Extract correspondences for learning BN population statistics.

    Parameters
    ----------
    training_graph : :class:`~blocks.graph.ComputationGraph`
        A graph of expressions wherein "training mode" batch normalization
        is taking place.
    allow_duplicates : bool, optional
        If `True`, allow multiple training-mode application calls from the
        same :class:`~blocks.bricks.BatchNormalization` instance, and
        return pairs corresponding to all of them. It's then the user's
        responsibility to do something sensible to resolve the duplicates.

    Returns
    -------
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore (if `allow_duplicates` is
        True) a single population variable may map to several different
        minibatch variables, and appear multiple times in this mapping.
        This can happen in recurrent models, siamese networks or other
        models that reuse pathways.

    Notes
    -----
    Used in their raw form, these updates will simply overwrite the
    population statistics with the minibatch statistics at every gradient
    step. You will probably want to transform these pairs into something
    more sensible, such as keeping a moving average of minibatch values,
    or accumulating an average over the entire training set once every few
    epochs.

    """
    # Avoid circular imports.
    from ..bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Locate every BatchNormalization.apply output in the graph and
    # keep only the application calls made in training mode.
    output_filter = VariableFilter(bricks=[BatchNormalization],
                                   roles=[OUTPUT])
    all_app_calls = [get_application_call(var)
                     for var in output_filter(training_graph)]
    train_app_calls = _training_mode_application_calls(all_app_calls)
    if not train_app_calls:
        raise ValueError("no training mode BatchNormalization "
                         "applications found in graph")
    bricks = [call.application.brick for call in train_app_calls]

    if not allow_duplicates and not isdistinct(bricks):
        raise ValueError('multiple applications of the same '
                         'BatchNormalization brick; pass allow_duplicates '
                         '= True to override this check')

    # Pair each population statistic with the corresponding minibatch
    # statistic recorded in the application call's metadata. Mean-only
    # bricks contribute no stdev pair.
    update_pairs = []
    for call in train_app_calls:
        brick = call.application.brick
        update_pairs.append((brick.population_mean, call.metadata['offset']))
        if not brick.mean_only:
            update_pairs.append((brick.population_stdev,
                                 call.metadata['divisor']))
    return update_pairs
273
|
|
|
|
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to the official Pylint documentation.