deepy.networks.AutoEncoder.test_cost()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 1
dl 0
loc 3
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from deepy.utils import AutoEncoderCost
5
import theano.tensor as T
6
7
from network import NeuralNetwork
8
9
class AutoEncoder(NeuralNetwork):
    """
    Auto encoder.

    Encoding layers must be stacked (``stack_encoders``) before decoding
    layers (``stack_decoders``); this order is not checked.

    Parameters:
        input_dim - dimension of the input, forwarded to NeuralNetwork
        rep_dim - dimension of representation; required by ``decode``
        config - network configuration, forwarded to NeuralNetwork
        input_tensor - optional input tensor, forwarded to NeuralNetwork
    """

    def __init__(self, input_dim, rep_dim=None, config=None, input_tensor=None):
        super(AutoEncoder, self).__init__(input_dim, config=config, input_tensor=input_tensor)

        self.rep_dim = rep_dim
        self.encoding_layers = []
        self.decoding_layers = []
        # Sub-networks are built lazily on the first encode()/decode() call
        # and cached afterwards.
        self.encoding_network = None
        self.decoding_network = None

    @property
    def encoding_layes(self):
        """
        Deprecated misspelled alias of ``encoding_layers``.

        Kept (readable and assignable) so existing callers keep working.
        """
        return self.encoding_layers

    @encoding_layes.setter
    def encoding_layes(self, value):
        self.encoding_layers = value

    def _cost_func(self, y):
        """
        Build the reconstruction cost between the network's first input
        variable and the given output expression ``y`` via AutoEncoderCost.
        """
        return AutoEncoderCost(self.input_variables[0], y).get()

    @property
    def cost(self):
        """Reconstruction cost computed on ``self.output``."""
        return self._cost_func(self.output)

    @property
    def test_cost(self):
        """Reconstruction cost computed on ``self.test_output``."""
        return self._cost_func(self.test_output)

    def stack_encoders(self, *layers):
        """
        Stack encoding layers, this must be done before stacking decoding layers.

        The layers are added to this network and also recorded in
        ``encoding_layers`` so ``encode`` can rebuild the encoder alone.
        """
        self.stack(*layers)
        self.encoding_layers.extend(layers)

    def stack_decoders(self, *layers):
        """
        Stack decoding layers.

        The layers are added to this network and also recorded in
        ``decoding_layers`` so ``decode`` can rebuild the decoder alone.
        """
        self.stack(*layers)
        self.decoding_layers.extend(layers)

    def encode(self, x):
        """
        Encode given input.

        On first use, a NeuralNetwork is built from the recorded encoding
        layers (stacked with ``no_setup=True``) and cached. Encoding layers
        stacked after the first call are NOT reflected in the cached network.
        """
        if self.encoding_network is None:
            self.encoding_network = NeuralNetwork(self.input_dim, self.network_config, self.input_tensor)
            for layer in self.encoding_layers:
                self.encoding_network.stack_layer(layer, no_setup=True)
        return self.encoding_network.compute(x)

    def decode(self, x):
        """
        Decode given representation.

        On first use, a NeuralNetwork with input dimension ``rep_dim`` is
        built from the recorded decoding layers (stacked with
        ``no_setup=True``) and cached. Decoding layers stacked after the
        first call are NOT reflected in the cached network.

        Raises:
            ValueError - if ``rep_dim`` was not set (or is 0).
        """
        if not self.rep_dim:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("rep_dim must be set to decode.")
        if self.decoding_network is None:
            self.decoding_network = NeuralNetwork(self.rep_dim, self.network_config)
            for layer in self.decoding_layers:
                self.decoding_network.stack_layer(layer, no_setup=True)
        return self.decoding_network.compute(x)
71