Completed
Push — master ( 15b7f6...48255b )
by Raphael
02:17
created

NeuralVariable.__getitem__()   B

Complexity

Conditions 7

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 1 Features 0
Metric Value
cc 7
dl 0
loc 13
rs 7.3333
c 2
b 1
f 0

1 Method

Rating | Name                             | Duplication | Size | Complexity
A      | NeuralVariable.getitem_wrapper() | 0           | 5    | 2
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from layer import NeuralLayer
5
from deepy.utils.decorations import neural_computation, convert_to_theano_var
6
7
8
class NeuralVariable(NeuralLayer):
    """
    A constant layer wrapping a pair of tensors.

    ``tensor`` is used in the normal graph and ``test_tensor`` in the test
    graph (it defaults to ``tensor``). Arithmetic operators, attribute access
    and calls are applied to both tensors and the results are wrapped in a new
    ``NeuralVariable``; indexing goes through ``neural_computation``.
    """

    # Attributes that __getattr__ must never forward to the wrapped tensors.
    # They are this object's own state, so if one is missing the instance is
    # not fully constructed yet (e.g. during NeuralLayer.__init__ or while
    # unpickling). Forwarding would either recurse forever (``tensor``) or
    # silently return a wrapped tensor attribute instead of raising.
    _OWN_ATTRIBUTES = frozenset(['tensor', 'test_tensor', 'output_dim'])

    def __init__(self, tensor, test_tensor=None, dim=0):
        """
        Create a tensor layer.

        :param tensor: tensor used in the normal graph
        :param test_tensor: tensor used in the test graph; ``tensor`` is reused
            when this is None
        :param dim: output dimension, returned by ``dim()``
        """
        super(NeuralVariable, self).__init__("const")
        self.output_dim = dim
        self.tensor = tensor
        # Test against None rather than truthiness: some Theano variables
        # (e.g. comparison results) raise TypeError when truth-tested, and a
        # falsy non-None test tensor should not be silently replaced.
        self.test_tensor = tensor if test_tensor is None else test_tensor
        self.initialize(0)

    def __getitem__(self, index):
        """
        Index the variable and wrap the result.

        A list index is converted to a tuple first, so ``v[[a, b]]`` is
        treated like ``v[a, b]``. The result's output dimension is the size of
        the last axis of the indexed tensor's test value, when a test value
        with at least one axis is available. Otherwise it is inherited from
        this variable.
        """
        # NOTE(review): neural_computation presumably unwraps the
        # NeuralVariable argument into its tensors and re-wraps the result;
        # see deepy.utils.decorations.
        @neural_computation
        def getitem_wrapper(t, index):
            if isinstance(index, list):
                index = tuple(index)
            return t[index]

        ret = getitem_wrapper(self, index)
        # A missing tag, a missing test value, or a test value without a
        # shape falls back to this variable's dimension instead of raising.
        test_value = getattr(getattr(ret.tensor, 'tag', None), 'test_value', None)
        shape = getattr(test_value, 'shape', None)
        if shape is not None and len(shape) > 0:
            ret.output_dim = shape[-1]
        else:
            ret.output_dim = self.dim()
        return ret

    def __call__(self, *args, **kwargs):
        """
        Call both wrapped tensors with the given arguments and wrap the results.

        Arguments go through ``convert_to_theano_var``, which yields separate
        argument sets for the normal and the test tensor.

        :raises Exception: if a raw Theano tensor is passed alongside neural
            variables
        """
        normal_args, test_args, tensor_found_in_args, _ = convert_to_theano_var(args)
        normal_kwargs, test_kwargs, tensor_found_in_kwargs, _ = convert_to_theano_var(kwargs)

        if tensor_found_in_args or tensor_found_in_kwargs:
            raise Exception("Theano tensor variables can not be used together with neural variables.")

        return NeuralVariable(self.tensor(*normal_args, **normal_kwargs),
                              self.test_tensor(*test_args, **test_kwargs),
                              dim=self.dim())

    def __getattr__(self, name):
        """
        Forward attribute lookups that failed on this object to both tensors.

        This object's own state attributes and dunder names are never
        forwarded; an AttributeError is raised for them instead. This keeps
        half-constructed instances from recursing forever, and keeps ``copy``
        and ``pickle`` (which probe for optional special methods) working
        normally.
        """
        if name in self._OWN_ATTRIBUTES or (name.startswith('__') and name.endswith('__')):
            raise AttributeError("'%s' object has no attribute '%s'" % (type(self).__name__, name))
        return NeuralVariable(getattr(self.tensor, name), getattr(self.test_tensor, name), dim=self.dim())

    def apply(self, func, dim=None):
        """
        Apply a function to both tensors.

        :param func: callable applied to ``tensor`` and to ``test_tensor``
        :param dim: output dimension of the result; when None, this variable's
            dimension is kept (an explicit 0 is honoured)
        :return: a new NeuralVariable wrapping both results
        """
        output_dim = self.output_dim if dim is None else dim
        return NeuralVariable(func(self.tensor), func(self.test_tensor), output_dim)

    def compute_tensor(self, x):
        """Return the wrapped tensor; the input is ignored."""
        return self.tensor

    def compute_test_tensor(self, x):
        """Return the wrapped test tensor; the input is ignored."""
        return self.test_tensor

    # Original (misspelled) name, kept in case callers or the base class use it.
    compute_test_tesnor = compute_test_tensor

    def set_test_value(self, value):
        """Attach a test value to ``tensor`` only; ``test_tensor`` is untouched."""
        self.tensor.tag.test_value = value

    def dim(self):
        """Return the output dimension."""
        return self.output_dim

    def _other_tensor(self, other):
        """Return ``other.tensor`` for a NeuralVariable, else ``other`` itself."""
        return other.tensor if isinstance(other, NeuralVariable) else other

    def _other_test_tensor(self, other):
        """Return ``other.test_tensor`` for a NeuralVariable, else ``other`` itself."""
        return other.test_tensor if isinstance(other, NeuralVariable) else other

    def __add__(self, other):
        return NeuralVariable(self.tensor + self._other_tensor(other),
                              self.test_tensor + self._other_test_tensor(other),
                              dim=self.dim())

    def __radd__(self, other):
        return NeuralVariable(self._other_tensor(other) + self.tensor,
                              self._other_test_tensor(other) + self.test_tensor,
                              dim=self.dim())

    def __sub__(self, other):
        return NeuralVariable(self.tensor - self._other_tensor(other),
                              self.test_tensor - self._other_test_tensor(other),
                              dim=self.dim())

    def __rsub__(self, other):
        return NeuralVariable(self._other_tensor(other) - self.tensor,
                              self._other_test_tensor(other) - self.test_tensor,
                              dim=self.dim())

    def __mul__(self, other):
        return NeuralVariable(self.tensor * self._other_tensor(other),
                              self.test_tensor * self._other_test_tensor(other),
                              dim=self.dim())

    def __rmul__(self, other):
        return NeuralVariable(self._other_tensor(other) * self.tensor,
                              self._other_test_tensor(other) * self.test_tensor,
                              dim=self.dim())

    def __div__(self, other):
        return NeuralVariable(self.tensor / self._other_tensor(other),
                              self.test_tensor / self._other_test_tensor(other),
                              dim=self.dim())

    def __rdiv__(self, other):
        return NeuralVariable(self._other_tensor(other) / self.tensor,
                              self._other_test_tensor(other) / self.test_tensor,
                              dim=self.dim())

    # Python 3 and ``from __future__ import division`` dispatch ``/`` to the
    # true-division hooks, which were previously undefined.
    __truediv__ = __div__
    __rtruediv__ = __rdiv__

    @property
    def test_value(self):
        """The test value attached to ``tensor`` (``tensor.tag.test_value``), or None."""
        # getattr with defaults: an AttributeError raised inside a property
        # would otherwise fall through to __getattr__ and be forwarded.
        return getattr(getattr(self.tensor, 'tag', None), 'test_value', None)

    @property
    def tv(self):
        """Shorthand for ``test_value``."""
        return self.test_value

    @property
    def ts(self):
        """Shape of ``test_value``, or None when no test value is set."""
        value = self.test_value
        if value is None:
            return None
        return value.shape