Completed
Push — master ( c8682f...5bbe2a )
by Raphael
01:33
created

deepy.layers.Chain.setup()   A

Complexity

Conditions 2

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 2
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numbers

from layer import NeuralLayer
5
6
7
class Chain(NeuralLayer):
    """
    Stack many layers to form a chain.

    This is useful to reuse layers in a customized layer.

    Usage:
        As part of the main pipeline:
            chain = Chain(layer1, layer2)
            model.stack(chain)
        As part of the computational graph:
            chain = Chain(layer1, layer2)
            y = chain.compute(x)

    Layers stacked while the input dimension is still unknown are queued and
    only connected when ``prepare`` runs; once the input dimension is known,
    ``stack`` connects new layers immediately.
    """

    def __init__(self, *layers):
        """
        :param layers: the layers to chain, in order. Passing a single
            integer instead is a deprecated form that only sets the input
            dimension of the chain.
        """
        super(Chain, self).__init__("chain")
        self.layers = []
        self._layers_to_stack = []
        if (len(layers) == 1
                and isinstance(layers[0], numbers.Integral)
                and not isinstance(layers[0], bool)):
            # Deprecated usage: Chain(input_dim).
            # numbers.Integral also accepts long / NumPy integers, which the
            # previous exact ``type(...) == int`` check rejected and then
            # tried to treat as layers.
            self.input_dim = layers[0]
        else:
            self.stack(*layers)

    def stack(self, *layers):
        """
        Append layers to the end of the chain.

        :param layers: layers to append, in order
        :return: this chain, so that calls can be chained
        """
        if self.input_dim is None or self.input_dim == 0:
            # The input dimension is not known until this chain is connected,
            # so queue the layers; prepare() registers them later.
            self._layers_to_stack.extend(layers)
        else:
            self._register_layers(*layers)
        return self

    def _register_layers(self, *layers):
        """
        Connect each layer to the output dimension of the previous one (the
        first layer to the chain's input dimension), append it to
        ``self.layers``, and register it as an inner layer.

        ``self.output_dim`` is updated to the output dimension of the last
        layer appended.
        """
        for layer in layers:
            if not self.layers:
                layer.connect(self.input_dim)
            else:
                layer.connect(self.layers[-1].output_dim)
            self.layers.append(layer)
            self.output_dim = layer.output_dim
        # Register only the layers added by this call. Layers appended by an
        # earlier call were already registered then; passing all of
        # self.layers again would register them a second time.
        # NOTE(review): assumes register_inner_layers accumulates rather
        # than replaces -- confirm against NeuralLayer.
        self.register_inner_layers(*layers)

    def prepare(self, *layers):
        """
        Register any layers that ``stack`` queued before the input dimension
        was known, then clear the queue.

        The ``layers`` arguments are accepted for interface compatibility
        and are not used.
        """
        if self._layers_to_stack:
            self._register_layers(*self._layers_to_stack)
            self._layers_to_stack = []

    def output(self, x):
        """Run ``x`` through every layer in order (training mode)."""
        return self._output(x, False)

    def test_output(self, x):
        """Run ``x`` through every layer in order (test mode)."""
        return self._output(x, True)

    def _output(self, x, test):
        """
        Feed ``x`` through ``self.layers`` in order, passing ``test`` to each
        layer's ``call``. With no layers, ``x`` is returned unchanged.
        """
        y = x
        for layer in self.layers:
            y = layer.call(y, test=test)
        return y
64