Completed
Pull Request — master (#941)
created by David, 01:46

blocks.graph.batch_normalize()   Rating: F

Complexity
  Conditions: 14

Size
  Total Lines: 102

Duplication
  Lines: 0
  Ratio: 0 %

Metric  Value
cc      14
dl      0
loc     102
rs      2

3 Methods

Rating  Name                                       Duplication  Size  Complexity
A       blocks.graph.prepare_replacement()         0            8     1
A       blocks.graph.get_application_call_dict()   0            3     2
A       blocks.graph.make_variable_filter()        0            3     1

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for its name.

The most commonly applied refactoring here is Extract Method, as described above; a minimal sketch follows.
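As an illustration only (none of these names come from the code under review), a small Python sketch of the Extract Method idea:

import numpy

# Before: one method doing two things, with a comment marking the second step.
def batch_statistics(inputs, epsilon):
    mean = inputs.mean(axis=0)
    # stabilize the standard deviation with a small constant
    stdev = (inputs.var(axis=0) + epsilon) ** 0.5
    return mean, stdev

# After: the commented part is extracted, and the comment inspires the name.
def stabilized_stdev(inputs, epsilon):
    """Standard deviation stabilized by a small constant."""
    return (inputs.var(axis=0) + epsilon) ** 0.5

def batch_statistics(inputs, epsilon):
    return inputs.mean(axis=0), stabilized_stdev(inputs, epsilon)

mean, stdev = batch_statistics(numpy.random.rand(10, 4), 1e-4)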

Complexity

Complex code like blocks.graph.batch_normalize() often does a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach is to look for fields or methods that share the same prefix or suffix.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster; a minimal sketch of the prefix-grouping idea follows below.
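A hypothetical Python sketch of Extract Class, with class and attribute names invented purely for illustration:

# Before: one class mixing concerns; the minibatch_* attributes share a prefix.
class Normalizer:
    def __init__(self):
        self.minibatch_mean = None
        self.minibatch_std = None
        self.population_mean = None
        self.population_std = None

# After Extract Class: the prefixed members move into a cohesive component.
class MinibatchStatistics:
    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

class Normalizer:
    def __init__(self):
        self.minibatch = MinibatchStatistics()
        self.population_mean = None
        self.population_std = None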

import collections
from theano import tensor

from . import add_annotation
from ..roles import (BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR,
                     BATCH_NORM_POPULATION_STATISTICS,
                     BATCH_NORM_MINIBATCH_ESTIMATE, INPUT, add_role,
                     has_roles)


def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    population_to_minibatch : OrderedDict
        A mapping of variables used in the original graph for population
        means and standard deviations to the minibatch-derived quantities
        that replace them. Useful to define updates in order to track
        the approximate population statistics during learning.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.

    """
    # Avoid a circular import.
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        from blocks.bricks import BatchNormalization
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    mean_filter, stdev_filter, input_filter = map(make_variable_filter,
                                                  [BATCH_NORM_OFFSET,
                                                   BATCH_NORM_DIVISOR, INPUT])

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    def get_application_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    means, stdevs, inputs = map(get_application_call_dict,
                                [mean_filter, stdev_filter, input_filter])

    assert (set(means.keys()) == set(stdevs.keys()) and
            set(means.keys()) == set(inputs.keys()))
    assert set(means.values()).isdisjoint(stdevs.values())

    replacements = []
    # Perform replacement for each application call.
    for application_call in means:
        axes = tuple(i for i, b in enumerate(means[application_call]
                                             .broadcastable) if b)
        minibatch_mean = inputs[application_call].mean(axis=axes,
                                                       keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(tensor.var(inputs[application_call],
                                               axis=axes, keepdims=True) +
                                    epsilon)
        minibatch_std.name = 'minibatch_divisor'

        def prepare_replacement(old, new, role, application_call):
            """Add roles and tags to replaced variables."""
            add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
            add_role(new, role)
            add_annotation(new, application_call)
            add_annotation(new, application_call.application.brick)
            new.tag.replacement_of = old
            replacements.append((old, new))

        prepare_replacement(means[application_call], minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    new_graph = computation_graph.replace(replacements)

    population_to_minibatch = collections.OrderedDict()
    for original_graph_node, replacement in replacements:
        pop_stats = original_graph_node
        while not has_roles(pop_stats, [BATCH_NORM_POPULATION_STATISTICS]):
            pop_stats = pop_stats.owner.inputs[0]
        # Above, we are replacing a node that has a batch axis added to it
        # with a replacement formed via a reduction with keepdims=True. In
        # order for the actual shared variable and the replacement to have
        # compatible dimensions, we need to drop the leading axis of the
        # replacement.
        replacement = replacement[0]
        assert pop_stats.dtype == replacement.dtype
        assert pop_stats.broadcastable == replacement.broadcastable
        population_to_minibatch[pop_stats] = replacement
    return new_graph, population_to_minibatch
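To show how the two return values fit together, here is a minimal usage sketch. It assumes batch_normalize is importable from blocks.graph as proposed in this pull request, and that cost is an existing Theano scalar whose graph contains BatchNormalization brick applications; the moving-average update is only one possible way to use population_to_minibatch, following the docstring's note about tracking approximate population statistics.

from blocks.graph import ComputationGraph, batch_normalize  # assumes this PR's API

# `cost` is a placeholder here: any Theano scalar built through
# BatchNormalization brick applications (model construction not shown).
cg = ComputationGraph([cost])
bn_cg, population_to_minibatch = batch_normalize(cg, epsilon=1e-4)

# Train against the transformed graph's output...
bn_cost = bn_cg.outputs[0]

# ...and keep the population statistics roughly in sync with extra updates,
# e.g. an exponential moving average (the decay value is illustrative only).
decay = 0.95
extra_updates = [(pop, decay * pop + (1 - decay) * minibatch)
                 for pop, minibatch in population_to_minibatch.items()]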