Completed
Push — master ( f73e69...91b7c0 )
by Raphael
01:35
created

convert_to_theano_var()   F

Complexity

Conditions 12

Size

Total Lines 40

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 12
dl 0
loc 40
rs 2.7855
c 1
b 0
f 0

How to fix   Complexity   

Complexity

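Complex functions like convert_to_theano_var() often do a lot of different things. To break such a function down, we need to identify a cohesive piece of logic within it — here, for example, each per-container branch (tuple, list, dict) and the repeated "found" flag bookkeeping. A common approach to find such a piece is to look for variables or branches that share the same prefixes or suffixes, or that repeat the same pattern.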

Once you have determined which parts belong together, you can apply the Extract Method refactoring (or Extract Class, when the code lives in a class). If the extracted component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
def convert_to_theano_var(obj):
    """
    Convert neural vars to theano vars.

    Containers are walked recursively and rebuilt with the same container
    type (tuple -> tuple, list -> list, dict -> dict). A ``NeuralVariable``
    is split into its ``tensor`` and ``test_tensor``. A theano
    ``TensorVariable`` is used for both outputs. Any other object is passed
    through unchanged for both outputs.

    :param obj: NeuralVariable or list or dict or tuple (possibly nested)
    :return: a 4-tuple ``(theano var, test var, tensor found, neural var found)``
        where the two flags report whether a ``TensorVariable`` or a
        ``NeuralVariable`` occurs anywhere inside ``obj``.
    """
    from theano.tensor.var import TensorVariable
    from deepy.layers.var import NeuralVariable
    # Exact type comparisons (rather than isinstance) are kept on purpose:
    # subclasses such as namedtuple or OrderedDict are passed through as-is
    # instead of being rebuilt as plain containers.
    if type(obj) == tuple:
        normal_list, test_list, tensor_found, neural_found = convert_to_theano_var(list(obj))
        # Rebuild both sides as tuples so tuple arguments are not silently
        # turned into lists before reaching the wrapped function.
        return tuple(normal_list), tuple(test_list), tensor_found, neural_found
    if type(obj) == list:
        converted = [convert_to_theano_var(item) for item in obj]
        normal_list = [result[0] for result in converted]
        test_list = [result[1] for result in converted]
        tensor_found = any(result[2] for result in converted)
        neural_found = any(result[3] for result in converted)
        return normal_list, test_list, tensor_found, neural_found
    if type(obj) == dict:
        converted = {key: convert_to_theano_var(obj[key]) for key in obj}
        normal_map = {key: result[0] for key, result in converted.items()}
        test_map = {key: result[1] for key, result in converted.items()}
        tensor_found = any(result[2] for result in converted.values())
        neural_found = any(result[3] for result in converted.values())
        return normal_map, test_map, tensor_found, neural_found
    if type(obj) == NeuralVariable:
        return obj.tensor, obj.test_tensor, False, True
    if type(obj) == TensorVariable:
        return obj, obj, True, False
    return obj, obj, False, False
44
45
def convert_to_neural_var(obj, test_obj):
    """
    Convert object and a test object into neural var.

    ``obj`` and ``test_obj`` are walked in parallel. Lists, tuples and dicts
    are rebuilt with the same container type. List elements are paired with
    ``zip``, so extra elements in the longer one are dropped. Dict values are
    looked up in ``test_obj`` by the same key. Each theano ``TensorVariable``
    found in ``obj`` is wrapped together with its counterpart from
    ``test_obj`` as ``NeuralVariable(obj, test_obj, 0)``. Any other value of
    ``obj`` is returned as-is, and its test counterpart is discarded.

    :param obj: tensor or list or dict or tuple
    :param test_obj: value mirroring the structure of ``obj`` (indexed by the
        same keys / positions) that holds the test-mode counterparts
    :return: ``obj`` with every TensorVariable replaced by a NeuralVariable
    """
    from theano.tensor.var import TensorVariable
    from deepy.layers.var import NeuralVariable
    obj_type = type(obj)
    if obj_type == dict:
        return {key: convert_to_neural_var(obj[key], test_obj[key]) for key in obj}
    if obj_type == tuple:
        return tuple(convert_to_neural_var(list(obj), list(test_obj)))
    if obj_type == list:
        return [convert_to_neural_var(normal, test) for normal, test in zip(obj, test_obj)]
    if obj_type == TensorVariable:
        # NOTE(review): the meaning of the third argument (0) is defined in
        # deepy.layers.var.NeuralVariable, which is not visible here.
        return NeuralVariable(obj, test_obj, 0)
    return obj
67
68
def neural_computation(original_func, prefer_tensor=False):
    """
    An annotation to enable theano-based functions to be called with NeuralVar.

    The returned wrapper unpacks its positional and keyword arguments with
    ``convert_to_theano_var`` and then does one of the following:

    * If a raw theano ``TensorVariable`` was passed, or no neural variable was
      passed and ``prefer_tensor`` is set, it calls ``original_func`` once on
      the plain arguments and returns that result unchanged.
    * Otherwise it calls ``original_func`` a second time on the test-mode
      arguments and merges both results with ``convert_to_neural_var``.

    :param original_func: theano-based callable to wrap
    :param prefer_tensor: a switch to return tensors when no neural variables
        are among the inputs
    :return: the wrapping function. It carries the ``__name__``/``__doc__``
        (and other standard attributes) of ``original_func`` when those exist.
        Calling it raises ``Exception`` when theano tensor variables and
        neural variables are mixed in the same call.
    """
    import functools

    def wrapper(*args, **kwargs):

        normal_args, test_args, tensor_found_in_args, neural_found_in_args = convert_to_theano_var(args)
        normal_kwargs, test_kwargs, tensor_found_in_kwargs, neural_found_in_kwargs = convert_to_theano_var(kwargs)

        tensor_found = tensor_found_in_args or tensor_found_in_kwargs
        neural_found = neural_found_in_args or neural_found_in_kwargs

        if tensor_found and neural_found:
            raise Exception("Theano tensor variables can not be used together with neural variables.")

        normal_result = original_func(*normal_args, **normal_kwargs)

        if tensor_found or (not neural_found and prefer_tensor):
            # No neural variables are inputted, so output tensors
            return normal_result
        # Output neural variables: evaluate once more on the test-mode inputs
        # and pair the two results.
        test_result = original_func(*test_args, **test_kwargs)
        return convert_to_neural_var(normal_result, test_result)

    # Preserve the identity of the wrapped callable (name, docstring, module)
    # so decorated functions stay introspectable. Only attributes the callable
    # actually has are copied: on Python 2, update_wrapper raises
    # AttributeError for callables (e.g. op instances) without a __name__.
    present = tuple(attr for attr in functools.WRAPPER_ASSIGNMENTS
                    if hasattr(original_func, attr))
    functools.update_wrapper(wrapper, original_func, assigned=present)
    return wrapper
98
99
def neural_computation_prefer_tensor(original_func):
    """
    Wrap ``original_func`` like ``neural_computation`` with ``prefer_tensor=True``.

    When none of the inputs is a neural variable, the wrapper returns the raw
    result of ``original_func`` instead of building neural variables.

    :param original_func: theano-based callable to wrap
    :return: the wrapping function produced by ``neural_computation``
    """
    decorated = neural_computation(original_func, prefer_tensor=True)
    return decorated