Completed
Push — master ( 48255b...bf2b0c )
by Raphael
01:13
created

Block.load_params()   A

Complexity

Conditions 1

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 7
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from layer import NeuralLayer
5
6
7
class Block(NeuralLayer):
    """
    Create a block, which contains the parameters of many connected layers.

    Layers are collected with ``register`` / ``register_layer``. Calling
    ``fix`` then passes each of them to ``register_inner_layers`` (inherited
    from NeuralLayer). Once fixed, no further layers can be added. The block
    itself returns its input unchanged from ``compute_tensor``.
    """

    def __init__(self):
        super(Block, self).__init__("block")
        # Sub layers collected so far, in registration order.
        self.layers = []
        # Set by fix(); once True, register_layer() refuses new layers.
        self.fixed = False

    def fix(self):
        """
        Fix the block, register all the parameters of sub layers.

        Every sub layer is checked before any of them is registered, so a
        failed call leaves the block untouched and still unfixed. Calling
        this on an already-fixed block does nothing.

        :raises Exception: if any registered sub layer is not initialized.
        """
        if self.fixed:
            return
        # Validate up front: checking and registering in the same loop would
        # leave the block half-registered when a later layer is uninitialized,
        # and a retry would then register the earlier layers a second time.
        if not all(layer.initialized for layer in self.layers):
            raise Exception("All sub layers in a block must be initialized when fixing it.")
        for layer in self.layers:
            self.register_inner_layers(layer)
        self.fixed = True

    def register(self, *layers):
        """
        Register many connected layers.

        :param layers: layers to add, in order; each is passed to
            ``register_layer``.
        :type layers: NeuralLayer
        :raises Exception: if the block has already been fixed.
        """
        for layer in layers:
            self.register_layer(layer)

    def register_layer(self, layer):
        """
        Register one connected layer.

        :type layer: NeuralLayer
        :raises Exception: if the block has already been fixed.
        """
        if self.fixed:
            raise Exception("After a block is fixed, no more layers can be registered.")
        self.layers.append(layer)

    def compute_tensor(self, x):
        """
        Return the input unchanged; a block only holds parameters.
        """
        return x

    def compute_test_tesnor(self, x):
        """
        Test-time computation: return the input unchanged.

        NOTE(review): the misspelled name is kept on purpose, since the
        NeuralLayer base (not visible here) may dispatch on this exact
        spelling. ``compute_test_tensor`` is the correctly spelled alias.
        """
        return x

    def compute_test_tensor(self, x):
        """
        Correctly spelled alias of ``compute_test_tesnor``: identity.
        """
        return self.compute_test_tesnor(x)

    def load_params(self, path, exclude_free_params=False):
        """
        Load parameters to the block.

        Wraps this block in a ComputationalGraph and delegates loading to it.

        :param path: passed straight to ``ComputationalGraph.load_params``
            (presumably a file path to saved parameters — confirm).
        :param exclude_free_params: forwarded unchanged to
            ``ComputationalGraph.load_params``.
        """
        # Imported locally, presumably to avoid a circular import between
        # layers and networks — TODO confirm.
        from deepy.networks.comp_graph import ComputationalGraph
        model = ComputationalGraph(blocks=[self])
        model.load_params(path, exclude_free_params=exclude_free_params)

    @property
    def all_parameters(self):
        """
        All parameters held by this block, i.e. ``self.parameters``.

        NOTE(review): presumably filled by ``register_inner_layers`` during
        ``fix`` — defined on NeuralLayer, not visible here.
        """
        return self.parameters