Completed
Pull Request — master (#941), created by David at 01:47

blocks.graph.apply_batch_normalization()    Rating: F

Complexity
    Conditions: 11

Size
    Total Lines: 80

Duplication
    Lines: 0
    Ratio: 0 %

Metric    Value
cc        11
dl        0
loc       80
rs        3.1764

2 Methods

Rating  Name                                  Duplication  Size  Complexity
A       blocks.graph.get_app_call_dict()      0            3     2
A       blocks.graph.make_variable_filter()   0            2     1

How to fix

Long Method

Small methods make your code easier to understand, in particular when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.
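To make the comment-to-method-name idea concrete, here is a minimal, hypothetical sketch of the Extract Method refactoring (all names are invented for illustration, not taken from blocks), in which each comment in a long method becomes the name of an extracted method:

```python
# Before: a long method whose commented sections are extraction candidates.
def report_stats_long(values):
    # compute the mean
    mean = sum(values) / len(values)
    # compute the variance
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var


# After Extract Method: each comment became a method name.
def compute_mean(values):
    return sum(values) / len(values)


def compute_variance(values):
    mean = compute_mean(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def report_stats(values):
    return compute_mean(values), compute_variance(values)
```

The behavior is unchanged; the refactored version simply gives each step a name that a reader (and the complexity metric) can follow.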

Commonly applied refactorings include Extract Method.

Complexity

Complex classes like blocks.graph.apply_batch_normalization() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
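As a sketch of Extract Class (the names below are hypothetical, not from blocks), a cluster of cohesive caching-related fields and methods is pulled out of a larger class into its own class:

```python
class ResultCache:
    """Extracted class: owns the fields and methods that belonged together."""

    def __init__(self, max_size):
        self.max_size = max_size
        self.entries = {}

    def get(self, key):
        return self.entries.get(key)

    def put(self, key, value):
        # Bounded cache: only store while under the size limit.
        if len(self.entries) < self.max_size:
            self.entries[key] = value


class ReportBuilder:
    """The original class now delegates to the extracted component."""

    def __init__(self):
        self.cache = ResultCache(max_size=100)

    def build(self, key, compute):
        cached = self.cache.get(key)
        if cached is not None:
            return cached
        result = compute()
        self.cache.put(key, result)
        return result
```

After the extraction, each class has a single responsibility, and the cache can be tested and reused independently.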

"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training
graph, which uses minibatch statistics in place of population statistics.

"""
import collections
import contextlib

import theano

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.
    This behaves almost exactly like a regular :class:`~blocks.bricks.MLP`
    except that it contains :class:`~blocks.bricks.BatchNormalization`
    bricks placed before each activation function.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')

    Now, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7). Well, that sucks.
    try:
        for brick in bn:
            brick.__enter__()
        yield
    finally:
        for brick in bn[::-1]:
            brick.__exit__()
Bug introduced by: The loop variable brick might not be defined here.
def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore a single population variable
        may map to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode.
    # Materialize the filter into a list so that deleting from the dicts
    # below does not invalidate a lazy iterator on Python 3.
    remove = list(filter(lambda a: (a.metadata.get('training_mode', False) or
                                    a.application.application !=
                                    BatchNormalization.apply), inputs.keys()))
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Get rid of the copy made on the way into the original apply.
        op = inputs[app_call].owner.op
        assert (isinstance(op, theano.tensor.Elemwise) and
                isinstance(op.scalar_op, theano.scalar.basic.Identity))
        unpacked = inputs[app_call].owner.inputs[0]
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs
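The try/finally enter-all, exit-in-reverse pattern used in batch_normalization (chosen there because contextlib.nested is deprecated and ExitStack is unavailable on Python 2.7) can be sketched in a self-contained, library-free form. The names below are illustrative only; note that standard context managers take three arguments to __exit__, whereas Blocks bricks are called with none:

```python
import contextlib

entered = []


@contextlib.contextmanager
def tracked(name):
    # A toy context manager that records when it is entered and exited.
    entered.append(('enter', name))
    try:
        yield name
    finally:
        entered.append(('exit', name))


@contextlib.contextmanager
def enter_all(managers):
    """Enter every manager, yield, then exit them in reverse order."""
    active = []
    try:
        for m in managers:
            m.__enter__()
            active.append(m)
        yield
    finally:
        # Exit in reverse so the nesting unwinds like a chain of
        # `with` statements would.
        for m in reversed(active):
            m.__exit__(None, None, None)


with enter_all([tracked('a'), tracked('b')]):
    pass
```

Only managers that were successfully entered are recorded in `active`, so a failure partway through the enter loop still unwinds exactly the managers that need it.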