Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

var()   F

Complexity

Conditions 14

Size

Total Lines 45

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 0
Metric Value
cc 14
c 3
b 0
f 0
dl 0
loc 45
rs 2.7581

How to fix   Complexity   

Complexity

Complex classes like var() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano.tensor as TT
5
from theano.ifelse import ifelse as theano_ifelse
6
7
from deepy.core.tensor_conversion import neural_computation
8
from deepy.layers.concat import Concatenate
9
from wrapper import deepy_tensor
10
11
12
def concat(vars, axis=-1):
    """Shorthand alias for :func:`concatenate`.

    :param vars: sequence of variables to join
    :param axis: axis along which to join (default: last axis)
    """
    return concatenate(vars, axis=axis)
17
18
@neural_computation
def reverse(tensor, axis=-1):
    """
    Return *tensor* with the elements along ``axis`` reversed.

    :param tensor: input tensor
    :param axis: axis to flip (default: last axis)
    """
    ndim = tensor.ndim
    selectors = [slice(None)] * ndim
    selectors[axis] = slice(None, None, -1)
    # Index with a tuple of slices: indexing with a plain list is
    # deprecated in modern numpy/theano and is interpreted as fancy
    # (advanced) indexing rather than a multidimensional slice.
    ret = tensor[tuple(selectors)]
    # Carry the tracked last-dimension metadata over to the result, if any.
    if hasattr(tensor.tag, "last_dim"):
        ret.tag.last_dim = tensor.tag.last_dim
    return ret
27
28
@neural_computation
def activate(var, method):
    """
    Apply a named activation function to a variable.

    :param var: input var
    :param method: type of activation, such as `relu`,`tanh`,`sigmoid`
    """
    from activations import get_activation
    activation_fn = get_activation(method)
    return activation_fn(var)
37
38
def concatenate(vars, axis=-1):
    """
    Concatenate a list of variables along ``axis``.

    Accepts either deepy NeuralVariables or raw theano tensors;
    the output dimension is tracked only when joining on the last axis.
    """
    from deepy.core.neural_var import NeuralVariable
    first = vars[0]
    if not isinstance(first, NeuralVariable):
        # Raw theano tensors: delegate straight to theano.
        return TT.concatenate(vars, axis)
    result = Concatenate(axis=axis).compute(*vars)
    last_axis = first.tensor.ndim - 1
    if axis == -1 or axis == last_axis:
        # Joining on the last axis: output dims add up.
        result.output_dim = sum([x.output_dim for x in vars], 0)
    return result
50
51
@neural_computation
def ifelse(condition, then_branch, else_branch):
    """Symbolic if/else: lazily select one branch based on *condition*."""
    return theano_ifelse(condition, then_branch, else_branch)
54
55
56
def apply(func, *args, **kwargs):
    """
    Apply *func* to the underlying tensors of the given neural variables.

    :param func: a callable operating on raw theano tensors
    :param dim: optional keyword giving the output dimension;
        defaults to ``args[0].dim()`` when omitted
    :rtype: deepy.core.neural_var.NeuralVariable
    """
    from deepy.core.neural_var import NeuralVariable
    # Only fall back to args[0].dim() when 'dim' is absent, so an
    # explicit dim never requires args to be non-empty.
    if 'dim' in kwargs:
        dim = kwargs['dim']
    else:
        dim = args[0].dim()
    tensors = [a.tensor for a in args]
    return NeuralVariable(func(*tensors), dim)
60
61
def repeat(*args, **kwargs):
    """Thin pass-through to :func:`deepy_tensor.repeat`."""
    return deepy_tensor.repeat(*args, **kwargs)
63
64
def vars(*tensor_types):
    """
    Create multiple variables without specifying last dimension and shape.
    :rtype: list of deepy.core.neural_var.NeuralVariable
    """
    # Use a list comprehension rather than map() so the documented list
    # return type holds on both Python 2 and Python 3 (where map is lazy).
    return [var(t) for t in tensor_types]
70
71
72
def _set_random_test_value(variable, shape, env):
    """Attach a random test value of the given shape to *variable*."""
    test_val = env.numpy_rand.rand(*shape)
    if len(shape) > 0:
        test_val = test_val.astype(variable.tensor.dtype)
    elif variable.tensor.dtype.startswith("int"):
        # rand() with no dims yields a float scalar; int tensors need an int.
        test_val = 1
    variable.set_test_value(test_val)


def var(tensor_type, last_dim=0, test_shape=None):
    """
    Wrap a Theano tensor into the variable for defining neural network.
    :param last_dim: last dimension of tensor, 0 indicates that the last dimension is flexible
    :param test_shape: a shape (list/tuple) for the random test value, or a
        concrete value to use as the test value directly
    :rtype: deepy.core.neural_var.NeuralVariable
    """
    from deepy.core.neural_var import NeuralVariable
    from deepy.core.env import env
    from theano.tensor.var import TensorVariable
    # --- Create or wrap the variable ---
    if isinstance(tensor_type, NeuralVariable):
        var = tensor_type
        if last_dim != 0:
            var.output_dim = last_dim
    elif isinstance(tensor_type, TensorVariable):
        var = NeuralVariable(tensor_type, dim=last_dim)
    elif isinstance(tensor_type, str):
        # A tensor-type name such as "matrix" or "ivector".
        var = NeuralVariable(getattr(TT, tensor_type)(), dim=last_dim)
    else:
        raise Exception("tensor_type shall be a string, a TensorVariable or a NeuralVariable")
    # --- Set test value ---
    if test_shape:
        if not isinstance(test_shape, (list, tuple)):
            # Not a shape: treat it as the test value itself.
            var.set_test_value(test_shape)
        else:
            _set_random_test_value(var, test_shape, env)
    else:
        # Synthesize a default shape of 3, 6, 9, ... per dimension,
        # pinning the last dimension to the declared dim when known.
        dims = [(d + 1) * 3 for d in range(var.tensor.ndim)]
        # Guard against a 0-dim tensor, where dims[-1] would not exist.
        if dims and var.dim() != 0:
            dims[-1] = var.dim()
        _set_random_test_value(var, dims, env)
    return var
117
118
119
def is_neural_var(var):
    """Return True when *var* is a deepy NeuralVariable."""
    from deepy.core.neural_var import NeuralVariable
    return isinstance(var, NeuralVariable)
122
123
def is_theano_var(var):
    """Return True when *var* is a raw theano TensorVariable."""
    from theano.tensor.var import TensorVariable
    return isinstance(var, TensorVariable)