Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

Loop._build_loop_vars()   D

Complexity

Conditions 8

Size

Total Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 8
c 1
b 0
f 0
dl 0
loc 25
rs 4
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from deepy.utils import Scanner, MapDict
5
from dummy_loop_utils import get_dummy_args, finish_scan
6
import theano
7
8
class LoopVars(MapDict):
    """
    Mapping of the placeholder variables handed to the body of ``with Loop(...)``.

    Adds no behaviour of its own on top of ``MapDict``; the subclass only gives
    the loop's variable container a distinct, descriptive type name.
    """
    pass
12
13
class Loop(object):
    """
    Build a scan-style loop with the ``with`` statement.

    Usage sketch, based on ``__enter__``/``__exit__`` below::

        loop = Loop(sequences={...}, outputs={...}, non_sequences={...})
        with loop as loop_vars:
            loop_vars["h"] = step(loop_vars["x"], loop_vars["h"])
        result = loop.outputs["h"]

    Entering the context creates placeholder variables (via ``get_dummy_args``)
    for every sequence, every output that has an initial value, and every
    non-sequence. The body rebinds the output keys to new expressions. Leaving
    the context passes those expressions to ``finish_scan`` and stores the
    results, which are readable through ``outputs`` / ``get_outputs``.
    """

    def __init__(self, sequences=None, outputs=None, non_sequences=None, block=None, **kwargs):
        """
        A loop function to support "with" grammar.

        :param sequences: dict of name -> variable, passed as the scan ``sequences``.
        :param outputs: dict of name -> initial value of each output; a ``None``
            value marks an output without an initial value.
        :param non_sequences: dict of name -> variable, passed as the scan
            ``non_sequences``.
        :param block: optional object whose ``register_updates(*pairs)`` receives
            the updates produced by ``finish_scan``.
        :param kwargs: extra keyword arguments forwarded to ``get_dummy_args``.
        :raises Exception: if sequences, outputs or non_sequences is not a dict.
        """
        self._sequences = sequences if sequences else {}
        self._outputs = outputs if outputs else {}
        self._non_sequences = non_sequences if non_sequences else {}
        for arg in (self._sequences, self._outputs, self._non_sequences):
            if not isinstance(arg, dict):
                raise Exception("Arguments of Loop shall be dicts.")
        self._block = block
        self._kwargs = kwargs
        self._loop_vars = None
        self._dummy_nodes = None
        self._scan_local_vars = None
        self._ordered_out_keys = []
        self._scan_outputs = None

    def _build_loop_vars(self):
        """
        Create inner loop variables (placeholders used inside the ``with`` body).

        Does nothing if the variables were already built. Otherwise it sets
        ``_ordered_out_keys``, ``_scan_local_vars``, ``_loop_vars`` and
        ``_dummy_nodes``. Every non-None value in the three input dicts must
        expose ``.tensor`` and ``.dim()``.
        """
        from deepy.core.neural_var import NeuralVariable
        if self._loop_vars is not None:
            return
        # list() fixes one key order for all uses below, and lets "+" work on
        # Python 3, where keys() returns a view.
        self._ordered_out_keys = list(self._outputs.keys())
        seq_keys = list(self._sequences.keys())
        nonseq_keys = list(self._non_sequences.keys())
        # Outputs without an initial value get no placeholder.
        filled_out_keys = [k for k in self._ordered_out_keys
                           if self._outputs[k] is not None]
        outputs_info = [
            None if self._outputs[k] is None else self._outputs[k].tensor
            for k in self._ordered_out_keys
        ]
        dummy_tensors, self._scan_local_vars = get_dummy_args(
            sequences=[self._sequences[k].tensor for k in seq_keys],
            outputs_info=outputs_info,
            non_sequences=[self._non_sequences[k].tensor for k in nonseq_keys],
            **self._kwargs
        )
        # NOTE(review): assumes get_dummy_args returns one placeholder per
        # sequence, per non-None output and per non-sequence, in that order.
        dummy_map = dict(zip(seq_keys + filled_out_keys + nonseq_keys, dummy_tensors))
        arg_map = self._sequences.copy()
        arg_map.update(self._outputs)
        arg_map.update(self._non_sequences)
        loop_vars = LoopVars()
        for k, dummy_tensor in dummy_map.items():
            loop_vars[k] = NeuralVariable(dummy_tensor, dim=arg_map[k].dim())
        self._loop_vars = loop_vars
        # Snapshot of the untouched placeholders, taken before the ``with`` body
        # rebinds entries of _loop_vars; _scan_step replaces these nodes.
        self._dummy_nodes = dict(loop_vars.items())

    def __enter__(self):
        """Build the placeholders (once) and return them for the ``with`` body."""
        self._build_loop_vars()
        return self._loop_vars

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Run ``finish_scan`` on the expressions bound to the output keys.

        Registers any resulting updates on ``block`` and stores the results as
        a MapDict in ``_scan_outputs``. If the ``with`` body raised, nothing is
        built and the original exception propagates.

        :raises Exception: if an output key is missing from the loop variables.
        """
        if exc_type is not None:
            # Don't mask the body's error with a follow-up failure from finish_scan.
            return False
        from deepy.core.neural_var import NeuralVariable
        output_tensors = []
        for k in self._ordered_out_keys:
            if k not in self._loop_vars:
                raise Exception("{} can not be found in loop vars.".format(k))
            output_tensors.append(self._loop_vars[k].tensor)

        result_tensors, updates = finish_scan(output_tensors, self._scan_local_vars)
        if self._block and updates:
            # isinstance also catches dict subclasses; splatting a mapping
            # directly would pass only its keys instead of (key, value) pairs.
            if isinstance(updates, dict):
                updates = list(updates.items())
            self._block.register_updates(*updates)

        outputs = MapDict()
        for k, tensor in zip(self._ordered_out_keys, result_tensors):
            out_var = NeuralVariable(tensor)
            if self._outputs[k] is not None:
                out_var.output_dim = self._outputs[k].dim()
            outputs[k] = out_var
        self._scan_outputs = outputs

    def _scan_step(self, vars):
        """
        Internal scan with dummy input variables.

        Clones each output expression built in the ``with`` body, replacing the
        original placeholders with the given variables.

        :param vars: mapping of loop-variable name -> variable whose ``.tensor``
            replaces that placeholder; ``None`` entries are skipped.
        :return: dict of output key -> NeuralVariable of the cloned expression.
        :raises Exception: if the loop was never entered, or an output key is
            missing from the loop variables.
        """
        from deepy.core.neural_var import NeuralVariable
        if self._loop_vars is None or self._dummy_nodes is None:
            raise Exception("The loop is not initialized. To initialize the loop, use `with loop as vars`")
        replace_map = {}
        for k, var in vars.items():
            if var is not None:
                replace_map[self._dummy_nodes[k].tensor] = var.tensor
        outputs = {}
        for k in self._outputs:
            if k not in self._loop_vars:
                raise Exception("{} can not be found in loop vars.".format(k))
            output_node = theano.clone(self._loop_vars[k].tensor, replace_map)
            outputs[k] = NeuralVariable(output_node, dim=self._loop_vars[k].dim())
        return outputs

    @property
    def outputs(self):
        """
        Get the output of the loop.

        :rtype: MapDict
        :raises Exception: if the loop has not been executed (``with`` not exited).
        """
        # "is None" rather than truthiness: an executed loop with no outputs
        # yields an empty (falsy) mapping.
        if self._scan_outputs is None:
            raise Exception("The loop is not executed.")
        return self._scan_outputs

    def get_outputs(self, *args):
        """
        Get the outputs of the loop.
        Return specific variables by passing the keys to the arguments.

        With no arguments, return the whole output mapping (``None`` if the loop
        has not run). With one key, return that variable. With several keys,
        return a list in argument order. Unknown keys yield ``None``.

        :rtype: MapDict
        :raises Exception: if keys are requested before the loop has executed.
        """
        if not args:
            return self._scan_outputs
        scan_outputs = self.outputs
        output_vars = [scan_outputs.get(k) for k in args]
        if len(output_vars) == 1:
            return output_vars[0]
        return output_vars
130
131
132