Completed
Push — master ( 5bbe2a...9d73f5 )
by Raphael
01:33
created

deepy.utils.Scanner   A

Complexity

Total Complexity 28

Size/Duplication

Total Lines 53
Duplicated Lines 0 %
Metric Value
dl 0
loc 53
rs 10
wmc 28

3 Methods

Rating   Name   Duplication   Size   Complexity  
D _func_wrapper() 0 22 11
C __init__() 0 14 13
A compute() 0 9 4
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import theano
5
6
class Scanner(object):
    """
    Call `theano.scan` with dictionary parameters.

    ``sequences``, ``outputs_info`` and ``non_sequences`` are given as dicts
    mapping a name to a value instead of positional lists. The step function
    ``func`` receives a single dict keyed by those same names and must return
    a dict keyed by the names of ``outputs_info`` -- optionally as a
    ``(dict, updates)`` tuple.
    """

    def __init__(self, func, sequences=None, outputs_info=None, non_sequences=None, **kwargs):
        """
        :param func: step function; called with one dict argument, must return
            a dict (or a ``(dict, updates)`` tuple).
        :param sequences: dict of name -> value; ``None`` values are not passed
            to ``theano.scan``.
        :param outputs_info: dict of name -> initial value, passed to
            ``theano.scan`` in the same order as the dict's items.
        :param non_sequences: dict of name -> value; ``None`` values are not
            passed to ``theano.scan``.
        :param kwargs: forwarded unchanged to ``theano.scan``.
        :raises TypeError: if one of the three parameters is given (truthy)
            but is not a dict (dict subclasses such as OrderedDict are accepted).
        """
        for param in (sequences, outputs_info, non_sequences):
            # Falsy values (None, empty containers) mean "nothing given".
            if param and not isinstance(param, dict):
                raise TypeError("The parameter `sequences`, `outputs_info`, `non_sequences` must be dict.")
        self._func = func

        # Materialize keys and values as aligned lists: on Python 3 the dict
        # views are neither indexable nor concatenable with `+`.
        self._sequence_keys, self._sequence_values = self._split_dict(sequences)
        self._output_keys, self._output_values = self._split_dict(outputs_info)
        self._non_sequence_keys, self._non_sequence_values = self._split_dict(non_sequences)
        self._kwargs = kwargs

    @staticmethod
    def _split_dict(mapping):
        """Return ``(keys, values)`` of ``mapping`` as two aligned lists; empty lists for a falsy mapping."""
        if not mapping:
            return [], []
        items = list(mapping.items())
        return [key for key, _ in items], [value for _, value in items]

    def _func_wrapper(self, *args):
        """
        Adapt theano's positional step-function call to ``self._func``'s dict interface.

        ``args`` are matched, in order, to the names whose value is not
        ``None`` (sequences, then outputs_info, then non_sequences). Names
        whose value is ``None`` are passed to ``self._func`` with value ``None``.

        :returns: the outputs listed in ``outputs_info`` order (a single
            value when there is exactly one output), followed by the
            ``updates`` as a tuple when ``self._func`` returned truthy updates.
        :raises TypeError: if ``self._func`` does not return a dict.
        """
        all_keys = self._sequence_keys + self._output_keys + self._non_sequence_keys
        all_values = self._sequence_values + self._output_values + self._non_sequence_values
        valid_keys = [key for key, value in zip(all_keys, all_values) if value is not None]
        none_keys = [key for key, value in zip(all_keys, all_values) if value is None]

        dict_param = dict(zip(valid_keys, args))
        # Entries without a positional argument are still exposed, as None.
        dict_param.update(dict.fromkeys(none_keys))
        retval = self._func(dict_param)
        if isinstance(retval, tuple):
            dict_retval, updates = retval
        else:
            dict_retval, updates = retval, None
        if not isinstance(dict_retval, dict):
            raise TypeError("The return value of scanner function must be a dict")
        final_retval = [dict_retval[key] for key in self._output_keys]
        if len(final_retval) == 1:
            final_retval = final_retval[0]
        if updates:
            return final_retval, updates
        return final_retval

    def compute(self):
        """
        Run ``theano.scan`` with the stored arguments.

        :returns: ``(results, updates)`` where ``results`` maps each name of
            ``outputs_info`` to the corresponding output of ``theano.scan``.
        """
        # Concrete lists, not lazy `filter` objects, so that theano.scan
        # receives real sequences on Python 3 as well as Python 2.
        results, updates = theano.scan(
            self._func_wrapper,
            sequences=[value for value in self._sequence_values if value is not None],
            outputs_info=self._output_values,
            non_sequences=[value for value in self._non_sequence_values if value is not None],
            **self._kwargs)
        if not isinstance(results, list):
            results = [results]
        return dict(zip(self._output_keys, results)), updates
59