neural_computation()   F
last analyzed

Complexity

Conditions 11

Size

Total Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 3
Bugs 0 Features 0
Metric Value
cc 11
c 3
b 0
f 0
dl 0
loc 33
rs 3.1764

How to fix   Complexity   

Complexity

Complex functions like neural_computation() often do many different things. To break such a function down, identify a cohesive piece of logic within it. A common approach is to look for statements that work on the same variables or that repeat the same pattern (for example, branches with near-identical bodies).

Once you have identified the code that belongs together, you can apply the Extract Function (Extract Method) refactoring and move it into a helper. For complex classes, the equivalent refactorings are Extract Class and, if the component makes sense as a subclass, Extract Subclass, which is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from theano.tensor.var import TensorVariable
5
from deepy.utils.map_dict import MapDict
6
7
def _convert_items_to_theano_vars(items):
    """
    Apply convert_to_theano_var to every item and merge the found-flags.

    :param items: iterable of objects to convert
    :return: (list of converted items,
              True if any item contained a theano TensorVariable,
              True if any item contained a NeuralVariable)
    """
    converted = []
    tensor_found = False
    neural_found = False
    for item in items:
        normal_var, item_tensor_found, item_neural_found = convert_to_theano_var(item)
        converted.append(normal_var)
        if item_tensor_found:
            tensor_found = True
        if item_neural_found:
            neural_found = True
    return converted, tensor_found, neural_found


def convert_to_theano_var(obj):
    """
    Convert neural vars to theano vars, recursing into containers.

    Tuples, lists, dicts, MapDicts and slices are traversed recursively and
    rebuilt with the same container type. A NeuralVariable is replaced by its
    underlying ``tensor``, whose ``tag.last_dim`` is set to ``obj.dim()``.
    A TensorVariable is returned unchanged but flagged. Any other value is
    returned unchanged. Type checks are exact, so subclasses of these types
    fall through to the "unchanged" case.

    :param obj: NeuralVariable, TensorVariable, or a (possibly nested)
        tuple / list / dict / MapDict / slice of them
    :return: (converted object,
              True if a theano TensorVariable was found,
              True if a NeuralVariable was found)
    """
    from deepy.core.neural_var import NeuralVariable
    obj_type = type(obj)
    if obj_type is tuple:
        # Rebuild a real tuple (previously this came back as a list), so callers
        # that distinguish tuples from lists (e.g. indexing) see the original type.
        items, tensor_found, neural_found = _convert_items_to_theano_vars(obj)
        return tuple(items), tensor_found, neural_found
    if obj_type is list:
        return _convert_items_to_theano_vars(obj)
    if obj_type is dict or obj_type is MapDict:
        keys = list(obj)
        values, tensor_found, neural_found = _convert_items_to_theano_vars(
            obj[key] for key in keys)
        normal_map = dict(zip(keys, values))
        if obj_type is MapDict:
            normal_map = MapDict(normal_map)
        return normal_map, tensor_found, neural_found
    if obj_type is NeuralVariable:
        theano_tensor = obj.tensor
        # Remember the last dimension so convert_to_neural_var can restore output_dim.
        theano_tensor.tag.last_dim = obj.dim()
        return theano_tensor, False, True
    if obj_type is TensorVariable:
        return obj, True, False
    if obj_type is slice:
        bounds, tensor_found, neural_found = _convert_items_to_theano_vars(
            (obj.start, obj.stop, obj.step))
        return slice(*bounds), tensor_found, neural_found
    return obj, False, False
65
66
def convert_to_neural_var(obj):
    """
    Convert theano tensors to neural vars, recursing into containers.

    Lists, tuples, dicts and MapDicts are traversed recursively and rebuilt
    with the same container type. Each TensorVariable is wrapped in a
    NeuralVariable. If the tensor carries ``tag.last_dim`` (as set by
    convert_to_theano_var), that value becomes the neural var's
    ``output_dim``. Any other value is returned unchanged. Type checks are
    exact, so subclasses of these types are returned unchanged.

    :param obj: TensorVariable, or a (possibly nested) list / tuple / dict /
        MapDict containing them
    :return: an object of the same structure with every TensorVariable
        replaced by a NeuralVariable
    """
    from theano.tensor.var import TensorVariable
    from deepy.core.neural_var import NeuralVariable
    obj_type = type(obj)
    if obj_type is list:
        return [convert_to_neural_var(item) for item in obj]
    if obj_type is tuple:
        return tuple(convert_to_neural_var(item) for item in obj)
    if obj_type is dict or obj_type is MapDict:
        converted = dict((key, convert_to_neural_var(obj[key])) for key in obj)
        return MapDict(converted) if obj_type is MapDict else converted
    if obj_type is TensorVariable:
        neural_var = NeuralVariable(obj)
        if hasattr(obj, 'tag') and hasattr(obj.tag, 'last_dim'):
            neural_var.output_dim = obj.tag.last_dim
        return neural_var
    return obj
96
97
def neural_computation(original_func, prefer_tensor=False):
    """
    Wrap a theano-based function so it can be called with NeuralVariables.

    Before ``original_func`` is called, NeuralVariable arguments (including
    those nested in tuples, lists, dicts, MapDicts and slices) are replaced by
    their theano tensors. The result is returned as-is when theano tensors
    were passed in, or when ``prefer_tensor`` is set and no neural variables
    were passed in. Otherwise the result is converted to NeuralVariable(s).

    :param original_func: the theano-based callable to wrap
    :param prefer_tensor: if True, return the raw result even when the inputs
        contain neither theano tensors nor neural variables
    :return: the wrapping function; calling it raises ``Exception`` if theano
        tensors and neural variables are mixed in a single call
    """
    import functools

    def wrapper(*args, **kwargs):
        normal_args, tensor_found_in_args, neural_found_in_args = convert_to_theano_var(args)
        normal_kwargs, tensor_found_in_kwargs, neural_found_in_kwargs = convert_to_theano_var(kwargs)

        tensor_found = tensor_found_in_args or tensor_found_in_kwargs
        neural_found = neural_found_in_args or neural_found_in_kwargs

        if tensor_found and neural_found:
            raise Exception("Theano tensor variables can not be used together with neural variables.")

        normal_result = original_func(*normal_args, **normal_kwargs)

        if tensor_found or (not neural_found and prefer_tensor):
            # No neural variables were passed in, so return tensors unchanged.
            return normal_result

        # Return neural variables.
        result_var = convert_to_neural_var(normal_result)
        # When a non-scalar test value is attached, its last axis length
        # overrides any output_dim taken from tag.last_dim.
        if (isinstance(normal_result, TensorVariable) and
                hasattr(normal_result.tag, "test_value") and
                hasattr(normal_result.tag.test_value, "shape") and
                normal_result.tag.test_value.shape):
            result_var.output_dim = normal_result.tag.test_value.shape[-1]
        return result_var

    try:
        # Preserve __name__/__doc__ etc. of the wrapped function.
        functools.update_wrapper(wrapper, original_func)
    except AttributeError:
        # Python 2's update_wrapper requires __module__/__name__/__doc__ on the
        # wrapped callable; e.g. functools.partial objects lack __name__.
        # Copying metadata is best-effort; the wrapper works either way.
        pass
    return wrapper
130
131
def neural_computation_prefer_tensor(original_func):
    """
    Like neural_computation, but with ``prefer_tensor=True``.

    The returned wrapper still unwraps NeuralVariable inputs to theano
    tensors. However, it returns the raw result of ``original_func`` unless
    a neural variable was among the inputs.

    :param original_func: the theano-based callable to wrap
    :return: the wrapping function
    """
    wrapped = neural_computation(original_func, prefer_tensor=True)
    return wrapped