Pull Request — master (#941) by David
Completed (07:45, queued 02:31)

blocks.graph._training_mode_application_calls()   Rating: B

Complexity:  Conditions 5
Size:        Total Lines 12
Duplication: Lines 0 (Ratio 0 %)

Metric  Value
cc      5
dl      0
loc     12
rs      8.5454
"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training
graph, which uses minibatch statistics in place of population
statistics.

"""
import collections
import contextlib
from functools import partial

import theano
from toolz import isdistinct

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


def _training_mode_application_calls(application_calls):
Issue (Coding Style, Naming): the name _training_mode_application_calls
does not conform to the function naming convention
([a-z_][a-z0-9_]{0,30}$). This is a configurable Pylint check; a project
Pylint configuration takes precedence over the defaults.
    """Filter for application calls made in 'training mode'."""
    from ..bricks import BatchNormalization
    out = []
    for app_call in application_calls:
        assert isinstance(app_call.application.brick, BatchNormalization)
        assert app_call.application.application == BatchNormalization.apply
        if app_call.metadata.get('training_mode', False):
            out.append(app_call)
    return out
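
Reviewer note: a minimal sketch of the contract this filter relies on.
Only application calls whose metadata carries the 'training_mode' flag
(set by the batch_normalization context manager defined below) survive
the filter. The module path blocks.graph.bn and the BatchNormalization
constructor call are assumptions based on this report's header and the
Blocks API; treat this as an illustration, not a doctest.

import theano
from blocks.bricks import BatchNormalization
from blocks.filter import VariableFilter, get_application_call
from blocks.graph import ComputationGraph
from blocks.graph.bn import (_training_mode_application_calls,
                             batch_normalization)
from blocks.roles import OUTPUT

brick = BatchNormalization(input_dim=4)
x = theano.tensor.matrix()
with batch_normalization(brick):
    y_train = brick.apply(x)  # tagged with training_mode metadata
y_infer = brick.apply(x)      # uses population statistics; no flag

outputs = VariableFilter(bricks=[BatchNormalization],
                         roles=[OUTPUT])(ComputationGraph([y_train,
                                                           y_infer]))
calls = [get_application_call(v) for v in outputs]
# Only the call made inside the context manager passes the filter.
assert len(_training_mode_application_calls(calls)) == 1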


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch-normalized training graph. It can
    be somewhat less convenient, as it requires building the graph twice
    if one wishes to monitor the output of the inference graph during
    training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    To construct an output with batch normalization enabled, i.e.
    normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) or ExitStack (not available
    # on Python 2.7), so enter and exit the bricks manually.
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        for brick in bn[::-1]:
            brick.__exit__()
Issue (Bug, flagged by static analysis): the loop variable brick might
not be defined here. This appears to be a false positive: brick is the
loop variable of the for statement in the finally block itself, and if
bn is empty the loop body never executes.
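
Reviewer note: the comment above rules out contextlib.ExitStack only
because of Python 2.7 support. For reference, a Python-3-only sketch of
the same bookkeeping; ExitStack also guarantees that only bricks whose
__enter__ succeeded get exited, in reverse order:

import contextlib

from blocks.utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    from blocks.bricks import BatchNormalization
    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    with contextlib.ExitStack() as stack:
        for brick in bn:
            # enter_context() calls brick.__enter__() now and registers
            # brick.__exit__() to run when the stack unwinds.
            stack.enter_context(brick)
        yield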


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch-normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of the different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode. Take a
    # list copy of the keys so the dicts can be mutated during iteration
    # on Python 3, where keys() returns a live view.
    app_calls = list(inputs.keys())
    remove = _training_mode_application_calls(app_calls)
    for app_call in app_calls:
        if app_call in remove:
            for mapping in (inputs, outputs, means, stdevs):
                del mapping[app_call]

    replacements = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        # Re-apply the brick in training mode to obtain a minibatch-
        # normalized output, then schedule it as a replacement.
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
            new_app_call = get_application_call(new_output)
            assert new_app_call.metadata['training_mode']
        replacements.append((old_output, new_output))
    return computation_graph.replace(replacements)
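
Reviewer note: a sketch of how this transform composes with
batch_normalization_updates (defined below) to build a training
function. Importing both names from blocks.graph is an assumption based
on this report's header:

import theano
from blocks.bricks import BatchNormalizedMLP, Tanh
from blocks.graph import (ComputationGraph, apply_batch_normalization,
                          batch_normalization_updates)
from blocks.initialization import Constant, IsotropicGaussian

mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
                         weights_init=IsotropicGaussian(0.1),
                         biases_init=Constant(0))
mlp.initialize()
x = theano.tensor.matrix()
cg = apply_batch_normalization(ComputationGraph([mlp.apply(x)]))

# Raw pairs overwrite the population statistics with each minibatch's
# statistics on every call; see the Notes of batch_normalization_updates.
pairs = batch_normalization_updates(cg)
train_fn = theano.function([x], cg.outputs, updates=pairs)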


def batch_normalization_updates(training_graph, allow_duplicates=False):
    """Extract correspondences for learning BN population statistics.

    Parameters
    ----------
    training_graph : :class:`~blocks.graph.ComputationGraph`
        A graph of expressions wherein "training mode" batch normalization
        is taking place.
    allow_duplicates : bool, optional
        If `True`, allow multiple training-mode application calls from the
        same :class:`~blocks.bricks.BatchNormalization` instance, and
        return pairs corresponding to all of them. It is then the user's
        responsibility to resolve the duplicates sensibly.

    Returns
    -------
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistic on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore a single population variable
        may map to several different minibatch variables.

    Notes
    -----
    Used in their raw form, these updates will simply overwrite the
    population statistics with the minibatch statistics at every gradient
    step. You will probably want to transform these pairs into something
    more sensible, such as keeping a moving average of minibatch values,
    or accumulating an average over the entire training set once every few
    epochs.

    """
    from ..bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call
    var_filter = VariableFilter(bricks=[BatchNormalization], roles=[OUTPUT])
    all_app_calls = map(get_application_call, var_filter(training_graph))
    train_app_calls = _training_mode_application_calls(all_app_calls)
    if len(train_app_calls) == 0:
        raise ValueError("no training-mode BatchNormalization "
                         "applications found in graph")
    bricks = [c.application.brick for c in train_app_calls]

    if not allow_duplicates and not isdistinct(bricks):
        raise ValueError('multiple applications of the same '
                         'BatchNormalization brick; pass '
                         'allow_duplicates=True to override this check')

    # Pair a population shared variable (a brick attribute) with the
    # matching minibatch statistic stored in the application call's
    # metadata.
    def extract_pair(brick_attribute, metadata_key, app_call):
        return (getattr(app_call.application.brick, brick_attribute),
                app_call.metadata[metadata_key])

    mean_pair = partial(extract_pair, 'population_mean', 'offset')
    stdev_pair = partial(extract_pair, 'population_stdev', 'divisor')
    return sum([[mean_pair(a), stdev_pair(a)] for a in train_app_calls], [])
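
Reviewer note: the Notes above suggest transforming the raw pairs
before use. A minimal sketch of an exponential-moving-average wrapper
(moving_average_updates and decay are hypothetical names, not part of
Blocks):

def moving_average_updates(update_pairs, decay=0.1):
    """Blend minibatch statistics into the population variables.

    Turns raw (population, minibatch) pairs into updates that keep an
    exponential moving average rather than overwriting the population
    statistics on every step.
    """
    return [(population, (1 - decay) * population + decay * minibatch)
            for population, minibatch in update_pairs]

# Usage: theano.function(..., updates=moving_average_updates(pairs))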