Completed: Pull Request #941 (master) by David, created 01:51

blocks.graph.apply_batch_normalization() (rating: F)

Complexity

Conditions 10

Size

Total Lines 76

Duplication

Lines 0
Ratio 0 %
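Metric  Value
cc      10      (same value as Conditions above)
dl      0       (same value as Duplication Lines above)
loc     76      (same value as Total Lines above)
rs      3.956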

2 Methods

Rating  Name                                  Duplication  Size  Complexity
A       blocks.graph.get_app_call_dict()      0            3     2
A       blocks.graph.make_variable_filter()   0            2     1

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

Complexity

Complex code units like blocks.graph.apply_batch_normalization() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
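In the listing below, one candidate for such a cohesive component is the set of four parallel dicts (`inputs`, `outputs`, `means`, `stdevs`): they share the same keys, must stay the same size (hence the `assert`), and are always deleted from together. A minimal Extract Class-style sketch, assuming plain strings in place of Theano variables, bundles them into one record per application call; `BNCallVariables` and `group_by_app_call` are hypothetical names, not part of Blocks.

```python
from collections import OrderedDict, namedtuple

# Hypothetical record bundling the four variables that the listing keeps
# in four separate dicts keyed by the same application call.
BNCallVariables = namedtuple('BNCallVariables',
                             ['input', 'output', 'mean', 'stdev'])


def group_by_app_call(inputs, outputs, means, stdevs):
    """Merge four dicts with identical keys into one dict of records.

    Raises ValueError if the dicts do not share exactly the same keys,
    a stricter check than comparing their lengths.
    """
    keys = list(inputs)
    for other in (outputs, means, stdevs):
        if set(other) != set(keys):
            raise ValueError('all dicts must have the same application calls')
    return OrderedDict(
        (k, BNCallVariables(inputs[k], outputs[k], means[k], stdevs[k]))
        for k in keys)


# Usage with strings standing in for Theano variables.
grouped = group_by_app_call(
    OrderedDict([('call1', 'x1'), ('call2', 'x2')]),
    OrderedDict([('call1', 'y1'), ('call2', 'y2')]),
    OrderedDict([('call1', 'mu1'), ('call2', 'mu2')]),
    OrderedDict([('call1', 'sd1'), ('call2', 'sd2')]))
del grouped['call2']  # one deletion drops all four variables of a call
print(grouped['call1'].mean)  # mu1
```

With a single dict of records, the removal loop no longer has to iterate over four mappings, and a mismatch between them is caught when the records are built.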

"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into a training graph,
which uses minibatch statistics in place of population statistics.

"""
import collections
import contextlib

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This context manager provides an alternative
    mechanism for building the batch normalized training graph. It can be
    somewhat less convenient as it requires building the graph twice if
    one wishes to monitor the output of the inference graph during
    training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.
    This behaves almost exactly like a regular :class:`~blocks.bricks.MLP`
    except that it contains :class:`~blocks.bricks.BatchNormalization`
    bricks placed before each activation function.

    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')

    Next, we'll construct an output variable as we normally would. It is
    normalized using the *population* statistics, which by default are
    initialized to 0 (mean) and 1 (standard deviation), respectively.

    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> import numpy
    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) or ExitStack (not available
    # on Python 2.7). Well, that sucks.
    for brick in bn:
        brick.__enter__()
    try:
        yield
    finally:
        # Leave training mode even if the body of the `with` raised;
        # otherwise the bricks would stay in training mode for good.
        for brick in bn:
            brick.__exit__()
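The mechanism behind batch_normalization() can be illustrated without Theano. This sketch assumes Python 3.3+, so it can use contextlib.ExitStack, which the comment above rules out only because of Python 2.7. `ToyBatchNorm` and `toy_batch_normalization` are hypothetical stand-ins: the toy brick keeps a stack of training-mode flags and normalizes a list of floats with either its population statistics or the statistics of the minibatch it receives (no epsilon, no learned scale or shift). Because ExitStack unwinds every entered context, each brick leaves training mode even if the body of the `with` raises.

```python
import contextlib
import math


class ToyBatchNorm(object):
    """Hypothetical stand-in for a BatchNormalization brick.

    Normalizes a list of floats; no epsilon and no learned scale/shift.
    """

    def __init__(self, population_mean=0.0, population_stdev=1.0):
        self.population_mean = population_mean
        self.population_stdev = population_stdev
        self._training_mode = []  # non-empty while in training mode

    @property
    def training(self):
        return bool(self._training_mode)

    def __enter__(self):
        self._training_mode.append(True)
        return self

    def __exit__(self, *exc_info):
        self._training_mode.pop()
        return False  # never swallow exceptions

    def apply(self, batch):
        if self.training:
            # Training mode: statistics of this minibatch.
            mean = sum(batch) / len(batch)
            stdev = math.sqrt(sum((v - mean) ** 2 for v in batch) / len(batch))
        else:
            # Inference mode: stored population statistics.
            mean, stdev = self.population_mean, self.population_stdev
        return [(v - mean) / stdev for v in batch]


@contextlib.contextmanager
def toy_batch_normalization(*bricks):
    """Keep every brick in training mode for the body of the `with`."""
    with contextlib.ExitStack() as stack:
        for brick in bricks:
            stack.enter_context(brick)
        yield


bn = ToyBatchNorm()
data = [0.0, 1.0, 2.0, 3.0]
inference = bn.apply(data)
with toy_batch_normalization(bn):
    training = bn.apply(data)
print(inference == data)  # True: population stats are mean 0, stdev 1
print(abs(sum(training)) < 1e-12)  # True: centred on the batch mean

# The brick leaves training mode even when the body raises.
try:
    with toy_batch_normalization(bn):
        raise RuntimeError('boom')
except RuntimeError:
    pass
print(bn.training)  # False
```

On Python 2.7, the try/finally around `yield` gives the same exception safety that ExitStack provides here.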


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        may appear in the graph, and therefore a single population variable
        may map to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group the filtered variables into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode. Build a
    # list first: on Python 3, filter() is lazy, and deleting from `inputs`
    # while iterating over its keys raises a RuntimeError.
    remove = [a for a in inputs
              if (a.metadata.get('training_mode', False) or
                  a.application.application != BatchNormalization.apply)]
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        unpacked = inputs[app_call].owner.inputs[0]
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs
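The update_pairs returned by apply_batch_normalization() only pair each population statistic with a minibatch estimate; how the population value is then updated is left to the caller. One common choice, shown here as an assumption rather than as Blocks' own behaviour, is an exponential moving average with a decay rate `alpha`. The sketch uses plain floats instead of Theano variables; `SharedValue` is a hypothetical stand-in that only mirrors the get_value()/set_value() method names of a Theano shared variable, and `apply_moving_average` is an illustrative name.

```python
class SharedValue(object):
    """Hypothetical stand-in for a Theano shared variable.

    Only get_value() and set_value() are modelled.
    """

    def __init__(self, value, name=None):
        self._value = value
        self.name = name

    def get_value(self):
        return self._value

    def set_value(self, value):
        self._value = value


def apply_moving_average(update_pairs, alpha=0.1):
    """Move each population statistic a fraction `alpha` towards the
    minibatch statistic it is paired with (exponential moving average)."""
    for population, minibatch_value in update_pairs:
        old = population.get_value()
        population.set_value((1 - alpha) * old + alpha * minibatch_value)


pop_mean = SharedValue(0.0, 'population_mean')
pop_stdev = SharedValue(1.0, 'population_stdev')
# Minibatch statistics as they might come out of one training step.
update_pairs = [(pop_mean, 1.5), (pop_stdev, 2.0)]
apply_moving_average(update_pairs, alpha=0.1)
print(pop_mean.get_value(), pop_stdev.get_value())  # roughly 0.15 and 1.1
```

Since a single population variable may appear in several pairs, this loop simply applies them one after another; averaging the minibatch estimates first would be an equally reasonable design.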