Completed
Pull Request — master (#941)
by David
01:46
created

blocks.graph.get_application_call_dict()   A

Complexity

Conditions 2

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 3
rs 10
"""Implements the batch normalization training graph transform.

Specifically, this module contains the implementation for the transformation
of a batch-normalized inference graph into training graph, which uses
minibatch statistics in place of population statistics.

"""
import collections

from theano import tensor

from . import add_annotation
from ..roles import (BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR,
                     BATCH_NORM_POPULATION_STATISTICS,
                     BATCH_NORM_MINIBATCH_ESTIMATE, INPUT, add_role,
                     has_roles)


def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.
    population_to_minibatch : OrderedDict
        A mapping of variables used in the original graph for population
        means and standard deviations to the minibatch-derived quantities
        that replace them. Useful to define updates in order to track
        the approximate population statistics during learning.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.

    """
    # Avoid a circular import.
    from ..filter import VariableFilter, get_application_call

    # Create filters for variables involved in a batch normalization brick
    # application.
    def make_variable_filter(role):
        # Imported here to avoid a circular import at module load time.
        from blocks.bricks import BatchNormalization
        return VariableFilter(bricks=[BatchNormalization], roles=[role])

    mean_filter, stdev_filter, input_filter = map(make_variable_filter,
                                                  [BATCH_NORM_OFFSET,
                                                   BATCH_NORM_DIVISOR, INPUT])

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    def get_application_call_dict(variable_filter):
        return collections.OrderedDict((get_application_call(v), v) for v in
                                       variable_filter(computation_graph))

    means, stdevs, inputs = map(get_application_call_dict,
                                [mean_filter, stdev_filter, input_filter])

    # Sanity checks: every application call must contribute exactly one
    # mean, one standard deviation and one input, and the mean and stdev
    # variables must be distinct objects.
    assert (set(means.keys()) == set(stdevs.keys()) and
            set(means.keys()) == set(inputs.keys()))
    assert set(means.values()).isdisjoint(stdevs.values())

    replacements = []

    # Hoisted out of the loop below: the helper is loop-invariant (it only
    # closes over `replacements`), so there is no need to re-create the
    # function object on every iteration.
    def prepare_replacement(old, new, role, application_call):
        """Add roles and tags to replaced variables."""
        add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(new, role)
        add_annotation(new, application_call)
        add_annotation(new, application_call.application.brick)
        # Keep a pointer back to the variable being replaced, for debugging.
        new.tag.replacement_of = old
        replacements.append((old, new))

    # Perform replacement for each application call.
    for application_call in means:
        # Reduce over exactly the axes on which the population statistic is
        # broadcastable -- i.e. the axes it was originally averaged over.
        axes = tuple(i for i, b in enumerate(means[application_call]
                                             .broadcastable) if b)
        minibatch_mean = inputs[application_call].mean(axis=axes,
                                                       keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(tensor.var(inputs[application_call],
                                               axis=axes, keepdims=True) +
                                    epsilon)
        minibatch_std.name = 'minibatch_divisor'

        prepare_replacement(means[application_call], minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    new_graph = computation_graph.replace(replacements)

    # Walk back from each replaced variable to the shared population
    # statistic it was derived from, and pair that statistic with the
    # corresponding minibatch estimate.
    population_to_minibatch = collections.OrderedDict()
    for original_graph_node, replacement in replacements:
        pop_stats = original_graph_node
        while not has_roles(pop_stats, [BATCH_NORM_POPULATION_STATISTICS]):
            pop_stats = pop_stats.owner.inputs[0]
        # Above, we are replacing a node that has a batch axis added to it
        # with a replacement formed via a reduction with keepdims=True. In
        # order for the actual shared variable and the replacement to have
        # compatible dimensions, we need to drop the leading axis of the
        # replacement.
        replacement = replacement[0]
        assert pop_stats.dtype == replacement.dtype
        assert pop_stats.broadcastable == replacement.broadcastable
        population_to_minibatch[pop_stats] = replacement
    return new_graph, population_to_minibatch