Pull Request — master (#941) by David
Status: Completed

blocks.graph.get_app_call_dict()    Rating: A

Complexity:  Conditions 2
Size:        Total Lines 3
Duplication: Lines 0, Ratio 0 %

Metric   Value
cc       2
dl       0
loc      3
rs       10
"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training
graph, which uses minibatch statistics in place of population
statistics.

"""
import collections
import contextlib

import theano

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch normalized training graph. It can
    be somewhat less convenient, as it requires building the graph
    twice if one wishes to monitor the output of the inference graph
    during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics,
    we simply make a similar call inside a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use nested() (deprecated) or ExitStack (not available on
    # Python 2.7). Well, that sucks.
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        for brick in bn[::-1]:
            brick.__exit__()
Bug introduced by: the loop variable `brick` might not be defined here.
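For reference, a minimal sketch of what the try/finally bookkeeping above could look like under a Python-3.3+-only assumption, where contextlib.ExitStack is available and unwinds its contexts in reverse order of entry:

    with contextlib.ExitStack() as stack:
        for brick in bn:
            # Enter each BatchNormalization brick's training-mode context;
            # the stack will call __exit__ on all of them in reverse order.
            stack.enter_context(brick)
        yield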


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_graph : :class:`~blocks.graph.ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistic computed on a minibatch. Note that
        multiple applications of a single
        :class:`blocks.bricks.BatchNormalization` brick may appear in
        the graph, and therefore a single population variable may map
        to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()

    Now, we'll construct an output variable as we would normally. This
    output is normalized using the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> x = theano.tensor.matrix()
    >>> y = mlp.apply(x)

    Finally, we'll create a :class:`~blocks.graph.ComputationGraph`
    and transform it to switch to minibatch standardization:

    >>> from blocks.graph import ComputationGraph
    >>> cg, _ = apply_batch_normalization(ComputationGraph([y]))
    >>> y_bn = cg.outputs[0]

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode. The
    # filter is materialized into a list so that deleting from the
    # dictionaries below is safe (filter() is lazy on Python 3).
    remove = list(filter(lambda a: (a.metadata.get('training_mode', False) or
                                    a.application.application !=
                                    BatchNormalization.apply), inputs.keys()))
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs
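To make the returned update_pairs concrete, here is a minimal sketch (not part of the module) of one way to consume them, continuing the docstring example above. The decay constant alpha and the variable names are illustrative only; the sketch attaches exponential-moving-average updates for the population statistics to a compiled training function:

    alpha = 0.05  # illustrative EMA decay constant
    cg_bn, update_pairs = apply_batch_normalization(ComputationGraph([y]))
    # Each pair maps a population shared variable to its minibatch statistic;
    # blend the two so population estimates track the training data.
    ema_updates = [(pop, (1 - alpha) * pop + alpha * mb)
                   for pop, mb in update_pairs]
    train_fn = theano.function([x], cg_bn.outputs, updates=ema_updates)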