"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into training graph,
which uses minibatch statistics in place of population statistics.

"""
import collections
import contextlib

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.
    This behaves almost exactly like a regular :class:`~blocks.bricks.MLP`
    except that it contains :class:`~blocks.bricks.BatchNormalization`
    bricks placed before each activation function.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')

    First, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7), so enter/exit the bricks by hand.  Track which
    # bricks were actually entered so that, if an __enter__() call raises
    # partway through the loop, we don't call __exit__() on bricks that
    # were never entered.
    entered = []
    try:
        for brick in bn:
            brick.__enter__()
            entered.append(brick)
        yield
    finally:
        # Unwind in reverse order of entry, mirroring nested `with`s.
        for brick in reversed(entered):
            brick.__exit__()


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore a single population variable
        may map to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # NOTE: the removal set must be materialized into a list *before* we
    # start deleting.  On Python 3, filter() is lazy; deleting from
    # `inputs` while a filter over inputs.keys() is still being consumed
    # raises "RuntimeError: dictionary changed size during iteration".
    remove = [a for a in inputs
              if (a.metadata.get('training_mode', False) or
                  a.application.application != BatchNormalization.apply)]
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # The INPUT variable wraps the original argument; take the first
        # input of its owner node to recover the raw pre-activation.
        # (Presumably an unpack/identity op -- confirm against
        # blocks' application-call input handling.)
        unpacked = inputs[app_call].owner.inputs[0]
        # Re-apply the brick in training mode (entering the brick's
        # context switches it to minibatch statistics).
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs