Completed
Pull Request — master (#941)
created by David at 02:31

blocks.graph.apply_batch_normalization()   F

Complexity
  Conditions: 10

Size
  Total Lines: 76

Duplication
  Lines: 0
  Ratio: 0 %

Metric   Value
cc       10
dl       0
loc      76
rs       3.956

2 Methods

Rating   Name   Duplication   Size   Complexity  
A blocks.graph.get_app_call_dict() 0 3 2
A blocks.graph.make_variable_filter() 0 2 1

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, when a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.
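As a minimal sketch of the Extract Method refactoring described above (the invoice example and all names here are invented for illustration, not taken from the analyzed code):

```python
# Before: a comment marks a fragment that wants to be its own method.
def print_owing_before(invoices):
    total = 0
    # calculate outstanding balance  <-- the comment hints at a method name
    for amount in invoices:
        total += amount
    print("Outstanding: {}".format(total))


# After: the commented fragment becomes a small, well-named method,
# and the comment's wording becomes the method name.
def calculate_outstanding(invoices):
    """Sum the amounts still owed across all invoices."""
    return sum(invoices)


def print_owing(invoices):
    print("Outstanding: {}".format(calculate_outstanding(invoices)))
```

The behavior is unchanged; the long method now reads as a sequence of named steps.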

Complexity

Complex methods like blocks.graph.apply_batch_normalization() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
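A hypothetical sketch of the Extract Class refactoring (the Person/Address example is invented for illustration): fields sharing the `address_` prefix signal a cohesive component that can move into its own class.

```python
# Before: the shared "address_" prefix marks a cohesive component.
class PersonBefore:
    def __init__(self, name, address_street, address_city):
        self.name = name
        self.address_street = address_street
        self.address_city = address_city


# After: the address fields are extracted into their own class.
class Address:
    def __init__(self, street, city):
        self.street = street
        self.city = city

    def label(self):
        return "{}, {}".format(self.street, self.city)


class Person:
    def __init__(self, name, address):
        self.name = name
        self.address = address
```

Behavior that only touched the address fields now lives on `Address`, shrinking `Person` and giving the component a name.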

"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the
transformation of a batch-normalized inference graph into training graph,
which uses minibatch statistics in place of population statistics.

"""
import collections
import contextlib

from ..roles import BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, INPUT, OUTPUT
from ..utils import find_bricks


@contextlib.contextmanager
def batch_normalization(*bricks):
    r"""Context manager to run batch normalization in "training mode".

    Parameters
    ----------
    \*bricks
        One or more bricks which will be inspected for descendant
        instances of :class:`~blocks.bricks.BatchNormalization`.

    Notes
    -----
    Graph replacement using :func:`apply_batch_normalization`, while
    elegant, can lead to Theano graphs that are quite large and result
    in very slow compiles. This provides an alternative mechanism for
    building the batch normalized training graph. It can be somewhat
    less convenient as it requires building the graph twice if one
    wishes to monitor the output of the inference graph during training.

    Examples
    --------
    First, we'll create a :class:`~blocks.bricks.BatchNormalizedMLP`.
    This behaves almost exactly like a regular :class:`~blocks.bricks.MLP`
    except that it contains :class:`~blocks.bricks.BatchNormalization`
    bricks placed before each activation function.

    >>> import numpy
    >>> import theano
    >>> from blocks.bricks import BatchNormalizedMLP, Tanh
    >>> from blocks.initialization import Constant, IsotropicGaussian
    >>> mlp = BatchNormalizedMLP([Tanh(), Tanh()], [4, 5, 6],
    ...                          weights_init=IsotropicGaussian(0.1),
    ...                          biases_init=Constant(0))
    >>> mlp.initialize()
    >>> x = theano.tensor.matrix('x')

    First, we'll construct an output variable as we would normally. This
    is getting normalized by the *population* statistics, which by
    default are initialized to 0 (mean) and 1 (standard deviation),
    respectively.

    >>> y = mlp.apply(x)

    And now, to construct an output with batch normalization enabled,
    i.e. normalizing pre-activations using per-minibatch statistics, we
    simply make a similar call inside of a `with` statement:

    >>> with batch_normalization(mlp):
    ...     y_bn = mlp.apply(x)

    Let's verify that these two graphs behave differently on the
    same data:

    >>> data = numpy.arange(12, dtype=theano.config.floatX).reshape(3, 4)
    >>> inf_y = y.eval({x: data})
    >>> trn_y = y_bn.eval({x: data})
    >>> numpy.allclose(inf_y, trn_y)
    False

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization

    bn = find_bricks(bricks, lambda b: isinstance(b, BatchNormalization))
    # Can't use either nested() (deprecated) nor ExitStack (not available
    # on Python 2.7). Well, that sucks.
    for brick in bn:
        brick.__enter__()
    try:
        yield
    finally:
        # Leave training mode even if the managed block raised.
        for brick in bn:
            brick.__exit__()


def apply_batch_normalization(computation_graph):
    """Transform a graph into a batch-normalized training graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    update_pairs : list of tuples
        A list of 2-tuples where the first element of each tuple is the
        shared variable containing a "population" mean or standard
        deviation, and the second is a Theano variable for the
        corresponding statistics on a minibatch. Note that multiple
        applications of a single :class:`blocks.bricks.BatchNormalization`
        brick may appear in the graph, and therefore a single population
        variable may map to several different minibatch variables.

    See Also
    --------
    :func:`batch_normalization`, for an alternative method to produce
    batch normalized graphs.

    """
    # Avoid circular imports.
    from blocks.bricks import BatchNormalization
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    # Group inputs and outputs into dicts indexed by application call.
    def get_app_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    # Compose these two so that we get 4 dicts, grouped by application
    # call, of different variable roles involved in BatchNormalization.
    inputs, outputs, means, stdevs = map(get_app_call_dict,
                                         map(make_variable_filter,
                                             [INPUT, OUTPUT, BATCH_NORM_OFFSET,
                                              BATCH_NORM_DIVISOR]))

    assert len(set([len(inputs), len(outputs), len(means), len(stdevs)])) == 1

    # Remove any ApplicationCalls that were not generated by apply(), or
    # were generated by an apply() while already in training mode. Build
    # a list up front so we don't mutate the dicts while a lazy filter
    # is still iterating over them.
    remove = [app_call for app_call in inputs
              if (app_call.metadata.get('training_mode', False) or
                  app_call.application.application !=
                  BatchNormalization.apply)]
    for app_call in remove:
        for mapping in (inputs, outputs, means, stdevs):
            del mapping[app_call]

    replacements = []
    update_pairs = []
    for app_call in inputs:
        old_output = outputs[app_call]
        # Unwrap the annotated copy of the input made by apply().
        unpacked = inputs[app_call].owner.inputs[0]
        with app_call.application.brick:
            new_output = app_call.application.brick.apply(unpacked)
        replacements.append((old_output, new_output))
        new_app_call = get_application_call(new_output)
        update_pairs.append((app_call.application.brick.population_mean,
                             new_app_call.metadata['offset']))
        update_pairs.append((app_call.application.brick.population_stdev,
                             new_app_call.metadata['divisor']))
    return computation_graph.replace(replacements), update_pairs