1 | """Implements the batch normalization training graph transform. |
||
2 | |||
3 | Specifically, this module contains the implementation for the |
||
4 | transformation of a batch-normalized inference graph into training graph, |
||
5 | which uses minibatch statistics in place of population statistics. |
||
6 | |||
7 | """ |
||
8 | import collections |
||
9 | import contextlib |
||
10 | from functools import partial |
||
11 | |||
12 | import theano |
||
13 | from toolz import isdistinct |
||
14 | |||
15 | from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT |
||
16 | from ..utils import find_bricks |
||
17 | |||
18 | |||
def _training_mode_application_calls(application_calls):
20 | """Filter for application calls made in 'training mode'.""" |
||
21 | from ..bricks import BatchNormalization |
||
22 | out = [] |
||
23 | for app_call in application_calls: |
||
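        # Sanity checks: every call passed in should be an application
        # of BatchNormalization.apply on a BatchNormalization brick.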
        assert isinstance(app_call.application.brick, BatchNormalization)
        assert app_call.application.application == BatchNormalization.apply
        if app_call.metadata.get('training_mode', False):
            out.append(app_call)
    return out


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch-normalized training graph. It can
    be somewhat less convenient, as it requires building the graph
    twice if one wishes to monitor the output of the inference graph
    during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is normalized by the *population* statistics, which by default are
    initialized to 0 (mean) and 1 (standard deviation), respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Now, to construct an output with batch normalization enabled, i.e.
    normalizing pre-activations using per-minibatch statistics, we
    simply make the same call inside a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use nested() (deprecated) or ExitStack (not available on
    # Python 2.7). Well, that sucks.
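    # Entering a BatchNormalization brick's context switches it to
    # training mode; exit the bricks in reverse order, as nested
    # `with` blocks would.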
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        for brick in bn[::-1]:
            brick.__exit__()


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization` : an alternative method to produce
        batch-normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    is normalized by the *population* statistics, which by default are
    initialized to 0 (mean) and 1 (standard deviation), respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get four dicts, grouped by application
    # call, of the different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

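    # All four dicts should be indexed by the same set of application
    # calls, one variable per role per call.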
    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # Materialize the keys up front, as we delete from the dicts below.
    app_calls = list(inputs.keys())
    remove = _training_mode_application_calls(app_calls)
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
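        # Re-apply the brick to the unpacked input inside its context
        # manager, so that the new application call runs in training
        # mode (using minibatch statistics).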
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        new_app_call = get_application_call(new_output)
        assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)


def get_batch_normalization_updates(training_graph, allow_duplicates=False):
    """Extract correspondences for learning BN population statistics.

    Parameters
    ----------
    training_graph : :class:`~blocks.graph.ComputationGraph`
        A graph of expressions wherein "training mode" batch
        normalization is taking place.
    allow_duplicates : bool, optional
        If `True`, allow multiple training-mode application calls from
        the same :class:`~blocks.bricks.BatchNormalization` instance,
        and return pairs corresponding to all of them. It's then the
        user's responsibility to do something sensible to resolve the
        duplicates.

    Returns
    -------
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore (if `allow_duplicates` is
        True) a single population variable may map to several different
        minibatch variables, and appear multiple times in this mapping.
        This can happen in recurrent models, siamese networks or other
        models that reuse pathways.

    Notes
    -----
    Used in their raw form, these updates will simply overwrite the
    population statistics with the minibatch statistics at every
    gradient step. You will probably want to transform these pairs into
    something more sensible, such as keeping a moving average of
    minibatch values, or accumulating an average over the entire
    training set once every few epochs.

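    As an illustration, one might keep an exponential moving average of
    the minibatch values (a sketch, not part of this function; ``alpha``
    is an arbitrary decay rate, and ``pop_updates`` is assumed to be the
    list returned by this function):

    >>> alpha = 0.1                                    # doctest: +SKIP
    >>> updates = [(p, m * alpha + p * (1 - alpha))    # doctest: +SKIP
    ...            for p, m in pop_updates]
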
247 | """ |
||
248 | from ..bricks import BatchNormalization |
||
249 | from ..filter import VariableFilter, get_application_call |
||
250 | var_filter = VariableFilter(bricks=[BatchNormalization], roles=[OUTPUT]) |
||
251 | all_app_calls = map(get_application_call, var_filter(training_graph)) |
||
252 | train_app_calls = _training_mode_application_calls(all_app_calls) |
||
253 | if len(train_app_calls) == 0: |
||
254 | raise ValueError("no training mode BatchNormalization " |
||
255 | "applications found in graph") |
||
256 | bricks = [c.application.brick for c in train_app_calls] |
||
257 | |||
258 | if not allow_duplicates and not isdistinct(bricks): |
||
259 | raise ValueError('multiple applications of the same ' |
||
260 | 'BatchNormalization brick; pass allow_duplicates ' |
||
261 | '= True to override this check') |
||
262 | |||
263 | def extract_pair(brick_attribute, metadata_key, app_call): |
||
264 | return (getattr(app_call.application.brick, brick_attribute), |
||
265 | app_call.metadata[metadata_key]) |
||
266 | |||
267 | mean_pair = partial(extract_pair, 'population_mean', 'offset') |
||
268 | stdev_pair = partial(extract_pair, 'population_stdev', 'divisor') |
||
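    # Bricks configured with mean_only=True maintain no population
    # standard deviation, so only the mean pair is emitted for them.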
    return sum([[mean_pair(a), stdev_pair(a)]
                if not a.application.brick.mean_only
                else [mean_pair(a)]
                for a in train_app_calls], [])