Completed: push to master ( 48255b...bf2b0c ) by Raphael, created 01:13

neural_computation()   F

Complexity
    Conditions: 11

Size
    Total Lines: 34

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 1
    Bugs: 0
    Features: 0

Metric   Value
cc       11
c        1
b        0
f        0
dl       0
loc      34
rs       3.1764

How to fix: Complexity

Complex functions like neural_computation() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for variables and blocks of code that share the same prefixes or suffixes.

Once you have determined the pieces that belong together, you can apply the Extract Method refactoring; for classes, Extract Class (or Extract Subclass, when the component makes sense as a subclass) is the analogous candidate. The sketch below illustrates this for the conversion code in this module.
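In this module, the cohesive component is the repeated per-element conversion loop: the list, dict, and slice branches of convert_to_theano_var all convert each element and OR together the "found" flags. A minimal Extract Method sketch, assuming convert_to_theano_var from the listing below, could look like the following; the helper name _convert_many is hypothetical and not part of deepy.

def _convert_many(items):
    # Hypothetical helper: convert each item with convert_to_theano_var and
    # fold the "tensor found" / "neural found" flags with a logical OR.
    normal_items, test_items = [], []
    theano_var_found = neural_var_found = False
    for item in items:
        normal_var, test_var, tensor_found, neural_found = convert_to_theano_var(item)
        normal_items.append(normal_var)
        test_items.append(test_var)
        theano_var_found = theano_var_found or tensor_found
        neural_var_found = neural_var_found or neural_found
    return normal_items, test_items, theano_var_found, neural_var_found

With such a helper, the list branch reduces to return _convert_many(obj), the slice branch to converting [obj.start, obj.stop, obj.step] and rebuilding the two slices, and the dict branch to converting the values and zipping the results back onto the keys, which lowers the number of conditions counted for the function.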

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from theano.tensor.var import TensorVariable


def convert_to_theano_var(obj):
    """
    Convert neural vars to theano vars.
    :param obj: NeuralVariable or list or dict or tuple
    :return: theano var, test var, tensor found, neural var found
    """
    from deepy.layers.neural_var import NeuralVariable
    if type(obj) == tuple:
        return tuple(convert_to_theano_var(list(obj)))
    if type(obj) == list:
        unpacked_list = list(map(convert_to_theano_var, obj))
        normal_list = []
        test_list = []
        theano_var_found = False
        neural_var_found = False
        for normal_var, test_var, tensor_found, neural_found in unpacked_list:
            normal_list.append(normal_var)
            test_list.append(test_var)
            if tensor_found: theano_var_found = True
            if neural_found: neural_var_found = True
        return normal_list, test_list, theano_var_found, neural_var_found
    elif type(obj) == dict:
        normal_map = {}
        test_map = {}
        theano_var_found = False
        neural_var_found = False
        for key in obj:
            normal_var, test_var, tensor_found, neural_found = convert_to_theano_var(obj[key])
            normal_map[key] = normal_var
            test_map[key] = test_var
            if tensor_found: theano_var_found = True
            if neural_found: neural_var_found = True
        return normal_map, test_map, theano_var_found, neural_var_found
    elif type(obj) == NeuralVariable:
        return obj.tensor, obj.test_tensor, False, True
    elif type(obj) == TensorVariable:
        return obj, obj, True, False
    elif type(obj) == slice:
        normal_args = []
        test_args = []
        theano_var_found = False
        neural_var_found = False
        for arg in [obj.start, obj.stop, obj.step]:
            normal_var, test_var, tensor_found, neural_found = convert_to_theano_var(arg)
            normal_args.append(normal_var)
            test_args.append(test_var)
            if tensor_found: theano_var_found = True
            if neural_found: neural_var_found = True
        return slice(*normal_args), slice(*test_args), theano_var_found, neural_var_found
    else:
        return obj, obj, False, False


def convert_to_neural_var(obj, test_obj):
    """
    Convert an object and a test object into a neural var.
    :param obj: tensor or list or dict or tuple
    :param test_obj: NeuralVar or list or dict or tuple
    :return:
    """
    from theano.tensor.var import TensorVariable
    from deepy.layers.neural_var import NeuralVariable
    if type(obj) == list:
        return [convert_to_neural_var(*item) for item in zip(obj, test_obj)]
    elif type(obj) == tuple:
        return tuple(convert_to_neural_var(list(obj), list(test_obj)))
    elif type(obj) == dict:
        merged_map = {}
        for key in obj:
            merged_map[key] = convert_to_neural_var(obj[key], test_obj[key])
        return merged_map
    elif type(obj) == TensorVariable:
        return NeuralVariable(obj, test_obj, 0)
    else:
        return obj


def neural_computation(original_func, prefer_tensor=False):
    """
    A decorator that enables Theano-based functions to be called with NeuralVariable arguments.
    :param original_func:
    :param prefer_tensor: a switch to return tensors when no neural inputs are given
    :return:
    """

    def wrapper(*args, **kwargs):
        normal_args, test_args, tensor_found_in_args, neural_found_in_args = convert_to_theano_var(args)
        normal_kwargs, test_kwargs, tensor_found_in_kwargs, neural_found_in_kwargs = convert_to_theano_var(kwargs)

        tensor_found = tensor_found_in_args or tensor_found_in_kwargs
        neural_found = neural_found_in_args or neural_found_in_kwargs

        if tensor_found and neural_found:
            raise Exception("Theano tensor variables can not be used together with neural variables.")

        normal_result = original_func(*normal_args, **normal_kwargs)

        if tensor_found or (not neural_found and prefer_tensor):
            # No neural variables were passed in, so return tensors
            return normal_result
        else:
            # Return neural variables and automatically set output_dim
            test_result = original_func(*test_args, **test_kwargs)
            result_var = convert_to_neural_var(normal_result, test_result)
            if (isinstance(normal_result, TensorVariable) and
                    hasattr(normal_result.tag, "test_value") and
                    hasattr(normal_result.tag.test_value, "shape") and
                    normal_result.tag.test_value.shape):
                result_var.output_dim = normal_result.tag.test_value.shape[-1]
            return result_var
    return wrapper


def neural_computation_prefer_tensor(original_func):
    return neural_computation(original_func, prefer_tensor=True)
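
For context, here is a minimal usage sketch of the decorator above. The wrapped function clipped_relu and its arguments are made up for illustration, and the snippet assumes it lives alongside the module above.

import theano.tensor as T

@neural_computation
def clipped_relu(x, cap=6.0):
    # A plain Theano expression; the decorator converts NeuralVariable
    # arguments to tensors before this body runs, then wraps the result
    # back into a NeuralVariable (unless only raw tensors were passed in).
    return T.minimum(T.maximum(x, 0.0), cap)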