|
1
|
|
|
import collections |
|
2
|
|
|
from theano import tensor |
|
3
|
|
|
|
|
4
|
|
|
from . import add_annotation |
|
5
|
|
|
from ..roles import (BATCH_NORM_OFFSET, BATCH_NORM_DIVISOR, |
|
6
|
|
|
BATCH_NORM_MINIBATCH_ESTIMATE, INPUT, add_role) |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
def batch_normalize(computation_graph, epsilon=1e-4):
    """Activate batch normalization in a graph.

    Parameters
    ----------
    computation_graph : instance of :class:`ComputationGraph`
        The computation graph containing :class:`BatchNormalization`
        brick applications.
    epsilon : float, optional
        The stabilizing constant for the minibatch standard deviation
        computation. Added to the variance inside the square root, as
        in the batch normalization paper.

    Returns
    -------
    batch_normed_computation_graph : instance of :class:`ComputationGraph`
        The computation graph, with :class:`BatchNormalization`
        applications transformed to use minibatch statistics instead
        of accumulated population statistics.

    Raises
    ------
    ValueError
        If the offset and divisor variables in the graph do not belong to
        the same application calls, if an application call carrying them
        has no input variable, or if one variable is tagged as both an
        offset and a divisor.

    Notes
    -----
    Assumes the minibatch axis is 0. Other axes are unsupported at
    this time. The minibatch statistics are reduced over every axis along
    which the corresponding population offset is broadcastable.

    """
    # Avoid a circular import.
    from ..filter import VariableFilter, get_application_call

    def get_application_call_dict(role):
        """Map each application call to the graph variable tagged `role`."""
        variable_filter = VariableFilter(roles=[role])
        return collections.OrderedDict(
            (get_application_call(v), v)
            for v in variable_filter(computation_graph))

    # Group means, standard deviations, and inputs into dicts indexed by
    # application call.
    means = get_application_call_dict(BATCH_NORM_OFFSET)
    stdevs = get_application_call_dict(BATCH_NORM_DIVISOR)
    # INPUT is a generic role rather than a batch-normalization-specific
    # one, so this may also pick up inputs of other bricks' applications.
    # Only the entries whose application call also has an offset are used.
    inputs = get_application_call_dict(INPUT)

    # Explicit checks (not asserts) so they survive `python -O`.
    mean_calls, stdev_calls = set(means), set(stdevs)
    if mean_calls != stdev_calls:
        raise ValueError(
            "batch normalization offsets and divisors belong to different "
            "application calls: {}".format(mean_calls ^ stdev_calls))
    missing_inputs = mean_calls - set(inputs)
    if missing_inputs:
        raise ValueError(
            "no input variable found for batch normalization application "
            "calls: {}".format(missing_inputs))
    if not set(means.values()).isdisjoint(stdevs.values()):
        raise ValueError("a variable is tagged as both a batch "
                         "normalization offset and divisor")

    replacements = []

    def prepare_replacement(old, new, role, call):
        """Add roles and tags to `new` and queue it to replace `old`."""
        add_role(new, BATCH_NORM_MINIBATCH_ESTIMATE)
        add_role(new, role)
        add_annotation(new, call)
        add_annotation(new, call.application.brick)
        new.tag.replacement_of = old
        replacements.append((old, new))

    # Perform replacement for each application call.
    for application_call, population_mean in means.items():
        input_ = inputs[application_call]
        # Reduce over the axes along which the population statistic is
        # broadcastable, so the minibatch estimate has a matching pattern.
        axes = tuple(i for i, b in
                     enumerate(population_mean.broadcastable) if b)
        minibatch_mean = input_.mean(axis=axes, keepdims=True)
        minibatch_mean.name = 'minibatch_offset'
        # Stabilize in the same way as the batch normalization manuscript.
        minibatch_std = tensor.sqrt(
            tensor.var(input_, axis=axes, keepdims=True) + epsilon)
        minibatch_std.name = 'minibatch_divisor'

        prepare_replacement(population_mean, minibatch_mean,
                            BATCH_NORM_OFFSET, application_call)
        prepare_replacement(stdevs[application_call], minibatch_std,
                            BATCH_NORM_DIVISOR, application_call)

    return computation_graph.replace(replacements)
|
90
|
|
|
|