|
1
|
|
|
#!/usr/bin/env python |
|
2
|
|
|
# -*- coding: utf-8 -*- |
|
3
|
|
|
|
|
4
|
|
|
import theano |
|
5
|
|
|
from map_dict import MapDict |
|
6
|
|
|
|
|
7
|
|
|
class Scanner(object):
    """
    Call `theano.scan` with dictionary parameters.

    Sequences, initial outputs and non-sequences are given as dicts mapping a
    name to a value. The step function ``func`` is called with a single
    ``MapDict`` keyed by all of those names (names whose value is ``None`` are
    present with a ``None`` value). It must return a dict keyed by the names in
    ``outputs_info``, or a ``(dict, updates)`` tuple.

    Parameters
    ----------
    func : callable
        Step function, called as ``func(dict_param)``.
    sequences, outputs_info, non_sequences : dict or None
        Named inputs for the corresponding ``theano.scan`` arguments, used in
        dict iteration order. ``None`` values in ``sequences`` and
        ``non_sequences`` are dropped before calling ``theano.scan``.
        ``None`` values in ``outputs_info`` are passed through unchanged.
    neural_computation : bool
        When true, step inputs/outputs are converted with deepy's
        ``convert_to_neural_var`` / ``convert_to_theano_var``, and
        ``last_dim`` / ``output_dim`` annotations are propagated.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``theano.scan``.
    """

    def __init__(self, func, sequences=None, outputs_info=None, non_sequences=None,
                 neural_computation=False, **kwargs):
        for param in (sequences, outputs_info, non_sequences):
            # A falsy value (None, {}) means "no inputs of this kind".
            if param and not isinstance(param, dict):
                raise TypeError("The parameter `sequences`, `outputs_info`, `non_sequences` must be dict.")
        self._func = func

        # Keys and values are snapshotted as parallel lists. On Python 3,
        # keys()/values() are live views that cannot be concatenated with `+`.
        self._sequence_keys, self._sequence_values = self._split_dict(sequences)
        self._output_keys, self._output_values = self._split_dict(outputs_info)
        self._non_sequence_keys, self._non_sequence_values = self._split_dict(non_sequences)
        self._kwargs = kwargs
        self._neural_computation = neural_computation
        # last_dim of each non-None input, filled by compute() in neural mode.
        self._input_dim_list = []
        # Output name -> dim() reported by the step function in neural mode.
        self._output_dim_map = {}

    @staticmethod
    def _split_dict(mapping):
        """Return ``(keys, values)`` of ``mapping`` as two aligned lists ([] for a falsy argument)."""
        if not mapping:
            return [], []
        return list(mapping.keys()), list(mapping.values())

    @staticmethod
    def _last_dim_of(value):
        """Return ``value.tag.last_dim`` when it exists, otherwise None.

        Uses ``getattr`` instead of truth-testing ``value``. ``bool()`` on an
        array or symbolic variable may raise, and values without a ``tag``
        attribute simply yield None.
        """
        return getattr(getattr(value, "tag", None), "last_dim", None)

    def _func_wrapper(self, *step_args):
        """Adapt theano's positional step arguments to the dict-based ``func``.

        Returns the step output(s) in ``outputs_info`` key order, unwrapped
        when there is exactly one output. If ``func`` returned non-empty
        updates, returns ``(outputs, updates)`` instead.

        Raises TypeError if ``func`` does not return a dict (or MapDict).
        """
        from deepy.core.tensor_conversion import convert_to_theano_var, convert_to_neural_var
        all_keys = self._sequence_keys + self._output_keys + self._non_sequence_keys
        all_values = self._sequence_values + self._output_values + self._non_sequence_values
        # Assumes theano passes one positional argument per non-None input, in
        # sequences / outputs / non-sequences order, and none for None inputs.
        valid_keys = [key for key, value in zip(all_keys, all_values) if value is not None]
        none_keys = [key for key, value in zip(all_keys, all_values) if value is None]
        if self._neural_computation:
            # compute() builds _input_dim_list over the non-None inputs only,
            # so it lines up one-to-one with step_args.
            for var, last_dim in zip(step_args, self._input_dim_list):
                var.tag.last_dim = last_dim

        dict_param = MapDict(list(zip(valid_keys, step_args)))
        dict_param.update(MapDict([(key, None) for key in none_keys]))
        if self._neural_computation:
            dict_param = convert_to_neural_var(dict_param)

        retval = self._func(dict_param)
        if isinstance(retval, tuple):
            dict_retval, updates = retval
        else:
            dict_retval, updates = retval, None

        if self._neural_computation:
            if isinstance(dict_retval, dict):
                # Remember output dimensions so compute() can annotate results.
                for key, out_var in dict_retval.items():
                    self._output_dim_map[key] = out_var.dim()
            updates, _, _ = convert_to_theano_var(updates)
            dict_retval, _, _ = convert_to_theano_var(dict_retval)

        if isinstance(dict_retval, MapDict):
            dict_retval = dict(dict_retval.items())
        if not isinstance(dict_retval, dict):
            raise TypeError("The return value of scanner function must be a dict")

        final_retval = [dict_retval[key] for key in self._output_keys]
        if len(final_retval) == 1:
            final_retval = final_retval[0]
        if updates:
            return final_retval, updates
        return final_retval

    def compute(self):
        """Run ``theano.scan`` and return ``(result_dict, updates)``.

        ``result_dict`` is a ``MapDict`` that pairs each ``outputs_info`` key
        with the corresponding scan result. ``updates`` is the second value
        returned by ``theano.scan``.
        """
        from deepy.core.tensor_conversion import convert_to_neural_var
        if self._neural_computation:
            all_values = self._sequence_values + self._output_values + self._non_sequence_values
            # Only non-None inputs get an entry. That keeps the list aligned
            # with the positional step arguments _func_wrapper receives.
            self._input_dim_list = [self._last_dim_of(value) for value in all_values
                                    if value is not None]

        results, updates = theano.scan(
            self._func_wrapper,
            sequences=[seq for seq in self._sequence_values if seq is not None],
            outputs_info=self._output_values,
            non_sequences=[arg for arg in self._non_sequence_values if arg is not None],
            **self._kwargs)
        if not isinstance(results, list):
            results = [results]

        result_dict = MapDict(list(zip(self._output_keys, results)))
        if self._neural_computation:
            result_dict = convert_to_neural_var(result_dict)
            for key in result_dict:
                if key in self._output_dim_map:
                    result_dict[key].output_dim = self._output_dim_map[key]
        return result_dict, updates