"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training
graph, which uses minibatch statistics in place of population
statistics.

"""
import collections
import contextlib

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch-normalized training graph. It can
    be somewhat less convenient, as it requires building the graph twice
    if one wishes to monitor the output of the inference graph during
    training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.
    This behaves almost exactly like a regular :class:`~blocks.bricks.MLP`
    except that it contains :class:`~blocks.bricks.BatchNormalization`
    bricks placed before each activation function.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')

    Next, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> y = mlp.apply(x)

    Now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) or ExitStack (not available
    # on Python 2.7). Well, that sucks.
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        # Make sure the bricks leave training mode even if the managed
        # block raises an exception.
        for brick in bn:
            brick.__exit__()


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistic computed on a minibatch. Note that
        multiple applications of a single
        :class:`blocks.bricks.BatchNormalization` brick may appear in
        the graph, and therefore a single population variable may map
        to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch-normalized graphs.
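
    Examples
    --------
    The sketch below shows one way the returned ``update_pairs`` might
    be folded into training updates; the exponential-moving-average rule
    and the decay constant ``alpha`` are illustrative assumptions, not
    something this function prescribes.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.graph import ComputationGraph
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')
    >>> cg = ComputationGraph([mlp.apply(x)])
    >>> cg_bn, update_pairs = apply_batch_normalization(cg)
    >>> alpha = 0.1  # assumed decay constant for the moving average
    >>> updates = [(pop, (1 - alpha) * pop + alpha * batch)
    ...            for pop, batch in update_pairs]

    The resulting ``updates`` list can then be passed wherever Theano
    update pairs are accepted, for instance to ``theano.function``.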

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT,
                                              BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    # All four dicts should have the same number of entries.
    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # Materialize the filter as a list so that the dicts can be mutated
    # below while we iterate (filter() is lazy on Python 3).
    remove = list(filter(lambda a: (a.metadata.get('training_mode', False) or
                                    a.application.application !=
                                    BatchNormalization.apply),
                         inputs.keys()))
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
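    # For each remaining application call, re-apply the brick to the
    # original input with the brick in training mode, schedule the old
    # output for replacement with the new one, and pair each population
    # statistic with the corresponding minibatch statistic.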
    for app_call in inputs:
        old_output = outputs[app_call]
        # The INPUT variable recorded on the application call is a copy
        # made by the application machinery; grab the variable it wraps
        # so that the brick can be re-applied to the original input.
        unpacked = inputs[app_call].owner.inputs[0]
        # Entering the brick's context manager puts it in training mode,
        # so this application normalizes with minibatch statistics.
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs