Completed
Push — master ( c8682f...5bbe2a )
by Raphael
01:33
created

deepy.layers.Chain   A

Complexity

Total Complexity 15

Size/Duplication

Total Lines 57
Duplicated Lines 0 %
Metric Value
dl 0
loc 57
rs 10
wmc 15

7 Methods

Rating   Name   Duplication   Size   Complexity  
A __init__() 0 9 3
A stack() 0 7 3
A output() 0 2 1
A _register_layers() 0 9 3
A test_output() 0 2 1
A prepare() 0 4 2
A _output() 0 5 2
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from layer import NeuralLayer
5
6
7
class Chain(NeuralLayer):
    """
    Stack many layers to form a chain.

    This is useful to reuse layers in a customized layer.

    While ``self.input_dim`` is ``None`` or ``0``, stacked layers are only
    queued; they are connected and registered when ``prepare`` runs.
    Otherwise ``stack`` connects them immediately.

    Usage:
        As part of the main pipe line:
            chain = Chain(layer1, layer2)
            model.stack(chain)
        As part of the computational graph:
            chain = Chain(layer1, layer2)
            y = chain.compute(x)
    """

    def __init__(self, *layers):
        """
        Create a chain from the given layers.

        :param layers: layers to stack, in order. As a deprecated form, a
            single integer is interpreted as the input dimension instead of
            a layer.
        """
        super(Chain, self).__init__("chain")
        # Layers that have already been connected, in call order.
        self.layers = []
        # Layers waiting for the input dimension to become known.
        self._layers_to_stack = []
        if len(layers) == 1 and isinstance(layers[0], int):
            # Deprecated usage: Chain(input_dim).
            self.input_dim = layers[0]
        else:
            self.stack(*layers)

    def stack(self, *layers):
        """
        Append layers to the end of the chain.

        :param layers: layers to add, in order.
        :return: this chain, so calls can be chained.
        """
        if self.input_dim is None or self.input_dim == 0:
            # The input dimension is not known yet, so the layers cannot be
            # connected; queue them until prepare() is called.
            self._layers_to_stack.extend(layers)
        else:
            self._register_layers(*layers)
        return self

    def _register_layers(self, *layers):
        """
        Connect each layer to its predecessor and record it in the chain.

        The first layer of the chain is connected to ``self.input_dim``;
        every later one is connected to the ``output_dim`` of the layer
        before it. ``self.output_dim`` tracks the last connected layer.
        """
        for layer in layers:
            if not self.layers:
                layer.connect(self.input_dim)
            else:
                layer.connect(self.layers[-1].output_dim)
            self.layers.append(layer)
            self.output_dim = layer.output_dim
        # NOTE(review): the full list is passed here, so layers connected by
        # an earlier call are handed to register_inner_layers again. Confirm
        # that the base class tolerates repeated registration.
        self.register_inner_layers(*self.layers)

    def prepare(self, *layers):
        """
        Connect and register any layers that were queued by ``stack``.

        Presumably invoked by the base class once the input dimension is
        known -- TODO confirm against NeuralLayer. The ``layers`` argument
        is accepted but not used.
        """
        if self._layers_to_stack:
            self._register_layers(*self._layers_to_stack)
            self._layers_to_stack = []

    def output(self, x):
        """Pass ``x`` through every layer with ``test=False``."""
        return self._output(x, False)

    def test_output(self, x):
        """Pass ``x`` through every layer with ``test=True``."""
        return self._output(x, True)

    def _output(self, x, test):
        """
        Feed ``x`` through the connected layers in order.

        :param x: input handed to the first layer.
        :param test: forwarded to each layer's ``call`` as ``test=``.
        :return: the result of the last layer, or ``x`` unchanged when the
            chain has no connected layers.
        """
        y = x
        for layer in self.layers:
            y = layer.call(y, test=test)
        return y
64