1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
from deepy.utils import Scanner, MapDict |
5
|
|
|
from dummy_loop_utils import get_dummy_args, finish_scan |
6
|
|
|
import theano |
7
|
|
|
|
8
|
|
|
class LoopVars(MapDict):
    """
    Container for the variables that are visible inside the loop body.

    An instance of this class is what ``Loop.__enter__`` hands to the
    ``with`` statement.
    """
12
|
|
|
|
13
|
|
|
class Loop(object):
    """
    A loop whose body is written inside a ``with`` statement.

    Entering the context yields a :class:`LoopVars` that holds one dummy
    variable for every sequence, every initialized output and every
    non-sequence. The body of the ``with`` statement re-assigns the output
    entries; on exit the loop is finalized through ``finish_scan`` and the
    results become available through :attr:`outputs` / :meth:`get_outputs`.
    """

    def __init__(self, sequences=None, outputs=None, non_sequences=None, block=None, **kwargs):
        """
        A loop function to support "with" grammar.

        :param sequences: dict mapping a name to a sequence variable
        :param outputs: dict mapping a name to the initial output variable;
            a value may be None for an output without an initial value
        :param non_sequences: dict mapping a name to a non-sequence variable
        :param block: optional object exposing ``register_updates``; it
            receives the updates returned by ``finish_scan``
        :param kwargs: forwarded unchanged to ``get_dummy_args``
        :raise Exception: if one of the three variable arguments is not a dict
        """
        self._sequences = sequences if sequences else {}
        self._outputs = outputs if outputs else {}
        self._non_sequences = non_sequences if non_sequences else {}
        # Every one of the three collections has to be a dict; the outputs
        # check used to be a no-op because of a stray "!= dict".
        for arg in (self._sequences, self._outputs, self._non_sequences):
            if not isinstance(arg, dict):
                raise Exception("Arguments of Loop shall be dicts.")
        self._block = block
        self._kwargs = kwargs
        self._loop_vars = None
        self._dummy_nodes = None
        self._scan_local_vars = None
        self._ordered_out_keys = []
        self._scan_outputs = None

    def _build_loop_vars(self):
        """
        Create inner loop variables.

        Does nothing when the variables were already built. Otherwise asks
        ``get_dummy_args`` for the dummy tensors, wraps each of them in a
        ``NeuralVariable`` and stores them in ``self._loop_vars``. The same
        dummy variables are remembered in ``self._dummy_nodes`` so that
        ``_scan_step`` can still find them after the loop body re-assigned
        entries of ``self._loop_vars``.
        """
        if self._loop_vars is not None:
            return
        from deepy.core.neural_var import NeuralVariable

        # list(...) keeps the key lists concatenable on Python 3 as well.
        self._ordered_out_keys = list(self._outputs.keys())
        seq_keys = list(self._sequences.keys())
        # Only outputs with an initial value are paired with a dummy tensor.
        filled_out_keys = [k for k in self._ordered_out_keys if self._outputs[k] is not None]
        nonseq_keys = list(self._non_sequences.keys())
        dummy_tensors, self._scan_local_vars = get_dummy_args(
            sequences=[self._sequences[k].tensor for k in seq_keys],
            # An output without initial value is passed through as None
            # instead of failing on ``None.tensor``.
            outputs_info=[
                self._outputs[k].tensor if self._outputs[k] is not None else None
                for k in self._ordered_out_keys
            ],
            non_sequences=[self._non_sequences[k].tensor for k in nonseq_keys],
            **self._kwargs
        )
        dummy_map = dict(zip(seq_keys + filled_out_keys + nonseq_keys, dummy_tensors))
        arg_map = self._sequences.copy()
        arg_map.update(self._outputs)
        arg_map.update(self._non_sequences)
        self._loop_vars = LoopVars()
        self._dummy_nodes = {}
        for k, dummy_tensor in dummy_map.items():
            dummy_var = NeuralVariable(dummy_tensor, dim=arg_map[k].dim())
            self._loop_vars[k] = dummy_var
            self._dummy_nodes[k] = dummy_var

    def __enter__(self):
        """
        Build the loop variables and return them.

        :rtype: LoopVars
        """
        self._build_loop_vars()
        return self._loop_vars

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Finalize the loop with the variables assigned inside the ``with`` body.

        When the body raised, nothing is finalized and the exception keeps
        propagating. Otherwise the tensors of all outputs are passed to
        ``finish_scan``, the returned updates are registered on the block
        (if both exist) and the results are stored for :attr:`outputs`.

        :raise Exception: if an output key is missing from the loop variables
        """
        if exc_type is not None:
            # Do not mask an error of the loop body with a follow-up failure.
            return False
        from deepy.core.neural_var import NeuralVariable

        output_tensors = []
        for k in self._ordered_out_keys:
            if k not in self._loop_vars:
                raise Exception("{} can not be found in loop vars.".format(k))
            output_tensors.append(self._loop_vars[k].tensor)

        result_tensors, updates = finish_scan(output_tensors, self._scan_local_vars)
        if self._block and updates:
            # isinstance also covers dict subclasses such as ordered dicts.
            if isinstance(updates, dict):
                updates = updates.items()
            self._block.register_updates(*updates)

        outputs = MapDict()
        for k, tensor in zip(self._ordered_out_keys, result_tensors):
            out_var = NeuralVariable(tensor)
            if self._outputs[k] is not None:
                out_var.output_dim = self._outputs[k].dim()
            outputs[k] = out_var
        self._scan_outputs = outputs
        return False

    def _scan_step(self, vars):
        """
        Internal scan with dummy input variables.

        Clones the tensor of every output while replacing the dummy nodes
        by the tensors of the given variables.

        :param vars: dict mapping a loop variable name to the variable that
            shall replace its dummy node; None values are skipped
        :return: dict mapping every output key to the cloned variable
        :raise Exception: if the loop is not initialized or an output key is
            missing from the loop variables
        """
        from deepy.core.neural_var import NeuralVariable

        if self._loop_vars is None or self._dummy_nodes is None:
            raise Exception("The loop is not initialized. To initialize the loop, use `with loop as vars`")
        replace_map = {}
        for k, var in vars.items():
            if var is not None:
                replace_map[self._dummy_nodes[k].tensor] = var.tensor
        outputs = {}
        for k in self._outputs:
            if k not in self._loop_vars:
                raise Exception("{} can not be found in loop vars.".format(k))
            output_node = theano.clone(self._loop_vars[k].tensor, replace_map)
            outputs[k] = NeuralVariable(output_node, self._loop_vars[k].dim())
        return outputs

    @property
    def outputs(self):
        """
        Get the output of the loop.

        :rtype: MapDict
        :raise Exception: if the ``with`` block has not been completed yet
        """
        if self._scan_outputs is None:
            raise Exception("The loop is not executed.")
        return self._scan_outputs

    def get_outputs(self, *args):
        """
        Get the outputs of the loop.
        Return specific variables by passing the keys to the arguments.

        With exactly one key the variable itself is returned, with several
        keys a list in the order of the keys (None for unknown keys), and
        without keys all stored outputs (None if the loop was not executed).

        :rtype: MapDict
        :raise Exception: if keys are given but the loop was not executed
        """
        if not args:
            return self._scan_outputs
        if self._scan_outputs is None:
            raise Exception("The loop is not executed.")
        # A real list, so that len() and indexing also work on Python 3.
        output_vars = [self._scan_outputs.get(key) for key in args]
        if len(output_vars) == 1:
            return output_vars[0]
        return output_vars
130
|
|
|
|
131
|
|
|
|
132
|
|
|
|