Completed
Push — master ( c8682f...5bbe2a )
by Raphael
01:33
created

prepare()   A

Complexity

Conditions 2

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 16
rs 9.4286
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from deepy import *
5
import theano.tensor as T
6
7
class AggregationLayer(NeuralLayer):
    """
    Aggregation layer.

    Builds a stack of ``layers`` Dense layers of width ``size`` (dropout is
    applied after each one) and returns a weighted sum of every intermediate
    output. The weights come from a gating network
    (Dense -> linear Dense with ``layers`` units -> Softmax) that is applied
    to the layer's raw input, so each sample gets its own mixture over the
    depths of the stack.

    :param size: width of every stacked Dense layer and of the output.
    :param activation: activation name forwarded to the stacked Dense layers
        and to the first Dense layer of the gating network.
    :param init: initializer forwarded to every Dense layer.
    :param layers: number of stacked Dense layers; must be >= 1.
    :param dropout: dropout rate applied after each stacked Dense layer
        (default 0.1, the value this layer always used).
    :raises ValueError: if ``layers`` is smaller than 1.
    """

    def __init__(self, size, activation='relu', init=None, layers=3, dropout=0.1):
        super(AggregationLayer, self).__init__("aggregation")
        # With layers < 1 the gating network would produce no (or a negative
        # number of) weights while the stack still holds one layer, so the
        # graph could never line up; fail early with a clear message instead.
        if layers < 1:
            raise ValueError("AggregationLayer requires layers >= 1, got %r" % (layers,))
        self.size = size
        self.activation = activation
        self.init = init
        self.layers = layers
        self.dropout_rate = dropout

    def prepare(self):
        """Create and register the stacked layers, the gating network and dropout."""
        self.output_dim = self.size
        # NOTE(review): ``_act`` is not referenced anywhere else in this class;
        # retained in case other code reads it.
        self._act = build_activation(self.activation)
        # The first layer consumes the layer input; every later layer consumes
        # the previous layer's ``size``-wide output.
        self._inner_layers = [
            Dense(self.size, self.activation, init=self.init).connect(
                self.input_dim if depth == 0 else self.size)
            for depth in range(self.layers)
        ]
        self.register_inner_layers(*self._inner_layers)

        # Gating network: maps the raw input to one softmax weight per
        # stacked layer.
        self._chain2 = Chain(self.input_dim).stack(
            Dense(self.size, self.activation, init=self.init),
            Dense(self.layers, 'linear', init=self.init),
            Softmax()
        )

        self.register_inner_layers(self._chain2)
        self._dropout = Dropout(self.dropout_rate)

    def _output(self, x, test=False):
        """
        Compute the gated sum of all stacked-layer outputs.

        :param x: layer input; the gating network is applied to it directly.
        :param test: forwarded to every sub-layer's ``call`` (the switch
            between training and test behaviour).
        :return: sum over the stack of (gate weight * layer output).
        """
        seq = []
        v = x
        for layer in self._inner_layers:
            v = layer.call(v, test)
            v = self._dropout.call(v, test)
            # dimshuffle(0, "x", 1) requires v to be 2-D; insert a length-1
            # middle axis so the outputs can be stacked along axis 1.
            seq.append(v.dimshuffle(0, "x", 1))

        # One slice per stacked layer along axis 1.
        seq_v = T.concatenate(seq, axis=1)

        # Gate weights: 2-D, with one column per stacked layer.
        eva = self._chain2.call(x, test)

        # Broadcast each weight over its layer's features, then sum out the
        # stack axis.
        result = seq_v * eva.dimshuffle((0, 1, "x"))
        result = result.sum(axis=1)
        return result

    def output(self, x):
        """Training-mode forward pass."""
        return self._output(x, False)

    def test_output(self, x):
        """Test-mode forward pass."""
        return self._output(x, True)
57
58
59
60