Total Complexity  | 3
Total Lines       | 21
Duplicated Lines  | 0 %
#!/usr/bin/env python

import theano.tensor as T

# NeuralLayer comes from the surrounding deepy-style framework; the exact
# import path is assumed here.
from deepy.layers import NeuralLayer


class BatchNormalization(NeuralLayer):
    """
    Batch normalization.
    http://arxiv.org/pdf/1502.03167v3.pdf
    """

    def __init__(self, epsilon=1e-6, weights=None):
        # The `weights` argument is accepted but not used by this layer.
        super(BatchNormalization, self).__init__("norm")
        self.epsilon = epsilon

    def prepare(self):
        # Learnable per-feature scale (gamma) and shift (beta).
        self.gamma = self.create_weight(shape=(self.input_dim,), suffix="gamma")
        self.beta = self.create_bias(self.input_dim, suffix="beta")
        self.register_parameters(self.gamma, self.beta)

    def compute_tensor(self, x):
        # Normalize each feature over the batch dimension, then scale and shift.
        m = x.mean(axis=0)
        # std = sqrt(var + epsilon); epsilon is added again in the denominator
        # below as an extra guard against division by zero.
        std = T.mean((x - m) ** 2 + self.epsilon, axis=0) ** 0.5
        x_normed = (x - m) / (std + self.epsilon)
        out = self.gamma * x_normed + self.beta
        return out
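For reference, the same per-feature normalization can be checked outside Theano. The snippet below is a minimal NumPy sketch of the arithmetic in compute_tensor; the batch_norm_forward function and the gamma/beta values are made up for illustration and are not part of the file above.

    import numpy as np

    def batch_norm_forward(x, gamma, beta, epsilon=1e-6):
        # x has shape (batch_size, input_dim); normalize each feature over the batch.
        m = x.mean(axis=0)
        std = np.sqrt(np.mean((x - m) ** 2 + epsilon, axis=0))
        x_normed = (x - m) / (std + epsilon)
        return gamma * x_normed + beta

    x = np.random.randn(32, 4)                                  # toy batch, 4 features
    out = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
    print(out.mean(axis=0))                                     # roughly 0 per feature
    print(out.std(axis=0))                                      # roughly 1 per feature

With gamma fixed at 1 and beta at 0 each feature comes out with roughly zero mean and unit variance over the batch; the learnable gamma and beta then let the layer recover any other scale and shift, as described in the paper linked in the docstring.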