Scanner._func_wrapper()   F
last analyzed

Complexity

Conditions 18

Size

Total Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 18
c 2
b 0
f 0
dl 0
loc 36
rs 2.7087

How to fix   Complexity   

Complexity

Complex methods like Scanner._func_wrapper() often do a lot of different things. To break such a method down, we need to identify cohesive groups of steps within it. A common approach is to look for statements that operate on the same variables or serve the same purpose (for example, building the step function's parameters versus post-processing its return value).

Once you have determined which statements belong together, you can apply the Extract Method refactoring. If a larger component emerges that spans several methods and fields, Extract Class is also a candidate; if that component makes sense as a sub-class, Extract Subclass is often faster.

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano
5
from map_dict import MapDict
6
7
class Scanner(object):
    """
    Call `theano.scan` with dictionary parameters.

    Sequences, initial outputs and non-sequences are supplied as dicts. The
    step function receives a single MapDict keyed by the same names. It must
    return either a dict of outputs, or a 2-tuple of (outputs dict, updates).
    """

    def __init__(self, func, sequences=None, outputs_info=None, non_sequences=None, neural_computation=False, **kwargs):
        """
        :param func: step function; called once per symbolic step with a MapDict
            mapping every configured key to its step variable (or None).
        :param sequences: dict name -> tensor iterated over by scan; entries whose
            value is None are not passed to scan and appear as None in `func`'s argument.
        :param outputs_info: dict name -> initial value of an output; entries whose
            value is None are not bound to a step variable and appear as None.
        :param non_sequences: dict name -> tensor given unchanged to every step;
            None values are skipped like sequences.
        :param neural_computation: if True, convert variables between deepy neural
            variables and theano variables around `func` and carry dimension
            metadata (`tag.last_dim`, `dim()`, `output_dim`) across the scan.
        :param kwargs: forwarded verbatim to `theano.scan`.
        :raises Exception: if any of the three mappings is non-empty and not a dict.
        """
        if any(mapping and not isinstance(mapping, dict)
               for mapping in (sequences, outputs_info, non_sequences)):
            raise Exception("The parameter `sequences`, `outputs_info`, `non_sequences` must be dict.")
        self._func = func

        sequences = sequences or {}
        outputs_info = outputs_info or {}
        non_sequences = non_sequences or {}
        # Materialize as lists: the key/value lists are concatenated and indexed
        # later, which Python 3 dict views do not support. keys() and values() of
        # the same unmodified dict iterate in matching order.
        self._sequence_keys = list(sequences.keys())
        self._sequence_values = list(sequences.values())
        self._output_keys = list(outputs_info.keys())
        self._output_values = list(outputs_info.values())
        self._non_sequence_keys = list(non_sequences.keys())
        self._non_sequence_values = list(non_sequences.values())

        self._kwargs = kwargs
        self._neural_computation = neural_computation
        # One last_dim per non-None input, aligned with the step function's positional args.
        self._input_dim_list = []
        # Output key -> dim() reported by the step function (neural computation only).
        self._output_dim_map = {}

    def _all_keys_and_values(self):
        """Return (keys, values) for sequences, outputs_info and non_sequences, in scan's argument order."""
        keys = self._sequence_keys + self._output_keys + self._non_sequence_keys
        values = self._sequence_values + self._output_values + self._non_sequence_values
        return keys, values

    def _func_wrapper(self, *step_vars):
        """
        Step function handed to `theano.scan`.

        Positional step variables are bound, in order, to the keys whose configured
        value is not None. Keys configured with None are passed to the user function
        as None. The user function's dict result is flattened back into the
        positional form scan expects.
        """
        from deepy.core.tensor_conversion import convert_to_theano_var, convert_to_neural_var
        keys, values = self._all_keys_and_values()
        bound_keys = [key for key, value in zip(keys, values) if value is not None]
        unbound_keys = [key for key, value in zip(keys, values) if value is None]

        if self._neural_computation:
            # compute() stores one dim per non-None input, so this zip lines up with step_vars.
            for var, last_dim in zip(step_vars, self._input_dim_list):
                var.tag.last_dim = last_dim

        dict_param = MapDict(zip(bound_keys, step_vars))
        dict_param.update(MapDict(zip(unbound_keys, [None] * len(unbound_keys))))
        if self._neural_computation:
            dict_param = convert_to_neural_var(dict_param)

        dict_retval, updates = self._split_retval(self._func(dict_param))

        if self._neural_computation:
            if isinstance(dict_retval, dict):
                # Remember each output's dimension so compute() can restore it on the results.
                for key, var in dict_retval.items():
                    self._output_dim_map[key] = var.dim()
            updates, _, _ = convert_to_theano_var(updates)
            dict_retval, _, _ = convert_to_theano_var(dict_retval)

        return self._flatten_outputs(dict_retval, updates)

    @staticmethod
    def _split_retval(retval):
        """Split the user function's result into (outputs, updates); updates is None when absent."""
        if isinstance(retval, tuple):
            dict_retval, updates = retval
            return dict_retval, updates
        return retval, None

    def _flatten_outputs(self, dict_retval, updates):
        """Order the output dict by output key and shape it the way theano.scan expects."""
        if isinstance(dict_retval, MapDict):
            dict_retval = dict(dict_retval.items())
        if not isinstance(dict_retval, dict):
            raise Exception("The return value of scanner function must be a dict")
        final_retval = [dict_retval[key] for key in self._output_keys]
        if len(final_retval) == 1:
            # A single output is returned bare rather than as a one-element list.
            final_retval = final_retval[0]
        if updates:
            return final_retval, updates
        return final_retval

    def compute(self):
        """
        Run `theano.scan` over the configured inputs.

        :returns: (MapDict of output key -> scan result, updates returned by theano.scan)
        """
        from deepy.core.tensor_conversion import convert_to_neural_var
        _, all_values = self._all_keys_and_values()
        if self._neural_computation:
            # Only non-None inputs become step arguments, so record dims for exactly
            # those; otherwise _func_wrapper would pair variables with the wrong dims.
            self._input_dim_list = [getattr(tensor.tag, 'last_dim', None)
                                    for tensor in all_values if tensor is not None]
        results, updates = theano.scan(
            self._func_wrapper,
            sequences=[tensor for tensor in self._sequence_values if tensor is not None],
            outputs_info=self._output_values,
            non_sequences=[tensor for tensor in self._non_sequence_values if tensor is not None],
            **self._kwargs)
        if not isinstance(results, list):
            results = [results]
        result_dict = MapDict(zip(self._output_keys, results))
        if self._neural_computation:
            result_dict = convert_to_neural_var(result_dict)
            for key in result_dict:
                if key in self._output_dim_map:
                    result_dict[key].output_dim = self._output_dim_map[key]
        return result_dict, updates