Completed
Push — master ( 5bbe2a...9d73f5 )
by Raphael
01:33
created

deepy.utils.Scanner.__init__()   C

Complexity

Conditions 13

Size

Total Lines 14

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 13
dl 0
loc 14
rs 5.2937

How to fix   Complexity   

Complexity

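Complex methods like `deepy.utils.Scanner.__init__()` often do a lot of different things. To break such a method down, we need to identify a cohesive component within it — here, the same "validate that it is a dict, then split it into keys and values" logic is repeated for `sequences`, `outputs_info` and `non_sequences`. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes (e.g. `_sequence_*`, `_output_*`, `_non_sequence_*`).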

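Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster. For a single over-complex method such as this one, Extract Method (moving the repeated logic into a helper) is usually the simplest first step.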

1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano
5
6
class Scanner(object):
    """
    Call `theano.scan` with dictionary parameters.

    ``sequences``, ``outputs_info`` and ``non_sequences`` are given as dicts
    mapping a name to a value. The step function ``func`` receives a single
    dict with one entry per name and must return a dict keyed by the names of
    ``outputs_info``, or a ``(dict, updates)`` tuple. Entries whose value is
    ``None`` are not passed to ``theano.scan``; ``func`` still sees their name,
    mapped to ``None``.
    """

    def __init__(self, func, sequences=None, outputs_info=None, non_sequences=None, **kwargs):
        """
        :param func: step function; takes one dict, returns a dict or a ``(dict, updates)`` tuple
        :param sequences: dict name -> sequence for ``theano.scan`` (``None`` values are skipped)
        :param outputs_info: dict name -> initial value of each output (values may be ``None``)
        :param non_sequences: dict name -> non-sequence argument (``None`` values are skipped)
        :param kwargs: extra keyword arguments forwarded verbatim to ``theano.scan``
        :raises TypeError: if one of the three containers is truthy but not a dict
        """
        # Falsy arguments (None, {}) are accepted and treated as "no entries".
        # isinstance (not an exact type check) so OrderedDict can fix the output order.
        for param in (sequences, outputs_info, non_sequences):
            if param and not isinstance(param, dict):
                raise TypeError("The parameter `sequences`, `outputs_info`, `non_sequences` must be dict.")
        self._func = func

        self._sequence_keys, self._sequence_values = self._split_items(sequences)
        self._output_keys, self._output_values = self._split_items(outputs_info)
        self._non_sequence_keys, self._non_sequence_values = self._split_items(non_sequences)
        self._kwargs = kwargs

    @staticmethod
    def _split_items(mapping):
        """Return the keys and values of ``mapping`` as two parallel lists (empty if falsy)."""
        # Materialize as lists: on Python 3 keys()/values() are views, which can
        # neither be concatenated with ``+`` nor handed to theano.scan as lists.
        if not mapping:
            return [], []
        return list(mapping.keys()), list(mapping.values())

    def _func_wrapper(self, *step_args):
        """
        Positional step function handed to ``theano.scan``.

        Rebuilds the keyed dict from the positional arguments, calls the user
        function, and flattens its dict result back into positional outputs.

        Assumes ``theano.scan`` passes one argument per non-``None`` entry, in the
        order sequences, outputs, non-sequences.

        :returns: the outputs in ``outputs_info`` key order (a single output is
            unwrapped, several are returned as a list), paired with ``updates`` in a
            tuple when the user function returned truthy updates
        :raises TypeError: if the user function does not return a dict
        """
        all_keys = self._sequence_keys + self._output_keys + self._non_sequence_keys
        all_values = self._sequence_values + self._output_values + self._non_sequence_values
        passed_keys = [k for k, v in zip(all_keys, all_values) if v is not None]
        none_keys = [k for k, v in zip(all_keys, all_values) if v is None]

        dict_param = dict(zip(passed_keys, step_args))
        # None-valued entries received no positional argument; expose them as None.
        dict_param.update((k, None) for k in none_keys)

        retval = self._func(dict_param)
        if isinstance(retval, tuple):
            dict_retval, updates = retval
        else:
            dict_retval, updates = retval, None
        if not isinstance(dict_retval, dict):
            raise TypeError("The return value of scanner function must be a dict")

        final_retval = [dict_retval[k] for k in self._output_keys]
        if len(final_retval) == 1:
            final_retval = final_retval[0]
        if updates:
            return final_retval, updates
        return final_retval

    def compute(self):
        """
        Run ``theano.scan`` over the stored inputs.

        :returns: ``(outputs, updates)``, where ``outputs`` is a dict mapping each
            ``outputs_info`` name to its scan result and ``updates`` is the
            updates object returned by ``theano.scan``
        """
        # Drop None-valued entries (_func_wrapper maps them to None itself) and
        # build real lists: on Python 3 ``filter`` would return a lazy iterator.
        sequences = [s for s in self._sequence_values if s is not None]
        non_sequences = [n for n in self._non_sequence_values if n is not None]
        results, updates = theano.scan(self._func_wrapper,
                                       sequences=sequences,
                                       outputs_info=self._output_values,
                                       non_sequences=non_sequences,
                                       **self._kwargs)
        # A non-list result (single output) is wrapped so it can be zipped with the keys.
        if not isinstance(results, list):
            results = [results]
        return dict(zip(self._output_keys, results)), updates
59