Completed
Pull Request — master (#941)
by David
01:49
created

blocks.graph.prepare_replacement()   A

Complexity

Conditions 1

Size

Total Lines 8

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 8
rs 9.4285
1
import collections
2
from theano import tensor
3
4
from . import add_annotation
5
from ..roles import (BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR,
6
                     BATCH_NORM_MINIBATCH_ESTIMATE, INPUT, add_role)
7
8
9
def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time.

    Only inputs belonging to application calls that also own a
    ``BATCH_NORM_OFFSET`` variable are used. Inputs of other bricks in
    the graph, which carry the same generic ``INPUT`` role, are ignored.

    """
    # Avoid a circular import.
    from ..filter import VariableFilter, get_application_call

    def get_application_call_dict(role):
        """Map each application call to the graph variable with `role`."""
        variable_filter = VariableFilter(roles=[role])
        return collections.OrderedDict(
            (get_application_call(v), v)
            for v in variable_filter(computation_graph))

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    means = get_application_call_dict(BATCH_NORM_OFFSET)
    stdevs = get_application_call_dict(BATCH_NORM_DIVISOR)
    # INPUT is shared by the inputs of *every* brick application, so keep
    # only those of the application calls that own a batch norm offset;
    # otherwise any other brick in the graph breaks the consistency check.
    all_inputs = get_application_call_dict(INPUT)
    inputs = collections.OrderedDict(
        (call, all_inputs[call]) for call in means if call in all_inputs)

    assert set(means.keys()) == set(stdevs.keys()), \
        "every batch norm offset needs a matching divisor"
    assert set(means.keys()) == set(inputs.keys()), \
        "every batch norm application needs an input"
    assert set(means.values()).isdisjoint(stdevs.values())

    replacements = []

    def prepare_replacement(old, new, role, application_call):
        """Add roles and tags to `new` and queue it as a replacement.

        `new` receives the BATCH_NORM_MINIBATCH_ESTIMATE role plus `role`,
        is annotated with the application call and its brick, records
        `old` in ``tag.replacement_of``, and the ``(old, new)`` pair is
        appended to the pending replacements.

        """
        add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(new, role)
        add_annotation(new, application_call)
        add_annotation(new, application_call.application.brick)
        new.tag.replacement_of = old
        replacements.append((old, new))

    # Perform replacement for each application call.
    for application_call, population_mean in means.items():
        input_ = inputs[application_call]
        # Statistics are reduced over the axes that are broadcastable in
        # the population mean being replaced.
        axes = tuple(i for i, b in enumerate(population_mean.broadcastable)
                     if b)
        minibatch_mean = input_.mean(axis=axes, keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(
            tensor.var(input_, axis=axes, keepdims=True) + epsilon)
        minibatch_std.name = 'minibatch_divisor'

        prepare_replacement(population_mean, minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    return computation_graph.replace(replacements)
90