1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
import theano |
5
|
|
|
from map_dict import MapDict |
6
|
|
|
|
7
|
|
|
class Scanner(object):
    """
    Call `theano.scan` with dictionary parameters.

    The user step function receives a single ``MapDict`` that maps every key of
    ``sequences``, ``outputs_info`` and ``non_sequences`` to the variable theano
    passes for it (keys whose value is ``None`` are mapped to ``None``). It must
    return either a dict keyed by the ``outputs_info`` keys, or a
    ``(dict, updates)`` tuple.
    """

    def __init__(self, func, sequences=None, outputs_info=None, non_sequences=None, neural_computation=False, **kwargs):
        """
        :param func: step function taking one dict-like argument.
        :param sequences: dict of name -> variable scanned over.
        :param outputs_info: dict of name -> initial state. Per the theano scan
            contract, an output whose initial state is ``None`` is not fed back
            into the step function.
        :param non_sequences: dict of name -> variable passed to every step.
        :param neural_computation: when true, values are converted with deepy's
            ``convert_to_neural_var`` / ``convert_to_theano_var`` and
            ``last_dim`` / ``output_dim`` information is propagated.
        :param kwargs: forwarded verbatim to ``theano.scan``.
        :raises Exception: if any of the three mappings is given (truthy) but
            is not a dict.
        """
        for mapping in (sequences, outputs_info, non_sequences):
            if mapping and not isinstance(mapping, dict):
                raise Exception("The parameter `sequences`, `outputs_info`, `non_sequences` must be dict.")
        self._func = func

        # Keys and values are stored as aligned *lists*: `_func_wrapper` and
        # `compute` concatenate them positionally, which dict views (Python 3)
        # do not support.
        self._sequence_keys, self._sequence_values = self._split_dict(sequences)
        self._output_keys, self._output_values = self._split_dict(outputs_info)
        self._non_sequence_keys, self._non_sequence_values = self._split_dict(non_sequences)
        self._kwargs = kwargs
        self._neural_computation = neural_computation
        # One entry per non-None input, in the order theano passes arguments.
        self._input_dim_list = []
        # Output key -> dim reported by the step function (neural mode only).
        self._output_dim_map = {}

    @staticmethod
    def _split_dict(mapping):
        """Return ``(keys, values)`` of ``mapping`` as aligned lists, or two empty lists if it is falsy."""
        if not mapping:
            return [], []
        keys = list(mapping.keys())
        return keys, [mapping[key] for key in keys]

    @staticmethod
    def _drop_none(values):
        """Return ``values`` as a list with every ``None`` entry removed."""
        return [value for value in values if value is not None]

    def _collect_input_dims(self):
        """
        Return ``tag.last_dim`` (or ``None`` when absent) for each non-None input.

        Inputs are visited in sequences -> outputs_info -> non_sequences order.
        ``None`` inputs are skipped because no positional argument is passed for
        them, so the result lines up one-to-one with ``_func_wrapper``'s args.
        """
        dims = []
        for tensor in self._sequence_values + self._output_values + self._non_sequence_values:
            if tensor is None:
                continue
            dims.append(getattr(tensor.tag, 'last_dim', None))
        return dims

    def _func_wrapper(self, *args):
        """
        Step function handed to ``theano.scan``.

        Rebuilds a keyed ``MapDict`` from the positional ``args``, calls the
        user function, and returns its outputs in ``outputs_info`` key order
        (a single value when there is exactly one output), paired with updates
        when the user function returned any.

        :raises Exception: if the user function does not return a dict.
        """
        from deepy.core.tensor_conversion import convert_to_theano_var, convert_to_neural_var
        all_keys = self._sequence_keys + self._output_keys + self._non_sequence_keys
        all_values = self._sequence_values + self._output_values + self._non_sequence_values
        valid_keys = [key for key, value in zip(all_keys, all_values) if value is not None]
        none_keys = [key for key, value in zip(all_keys, all_values) if value is None]
        if self._neural_computation:
            # `_input_dim_list` has one entry per non-None input, matching `args`.
            for var, last_dim in zip(args, self._input_dim_list):
                var.tag.last_dim = last_dim

        dict_param = MapDict(zip(valid_keys, args))
        dict_param.update(MapDict(zip(none_keys, [None] * len(none_keys))))
        if self._neural_computation:
            dict_param = convert_to_neural_var(dict_param)
        retval = self._func(dict_param)
        if type(retval) == tuple:
            dict_retval, updates = retval
        else:
            dict_retval, updates = retval, None
        if self._neural_computation:
            if isinstance(dict_retval, dict):
                for key, var in dict_retval.items():
                    self._output_dim_map[key] = var.dim()
            updates, _, _ = convert_to_theano_var(updates)
            dict_retval, _, _ = convert_to_theano_var(dict_retval)
        if type(dict_retval) == MapDict:
            dict_retval = dict(dict_retval.items())
        if not isinstance(dict_retval, dict):
            raise Exception("The return value of scanner function must be a dict")
        final_retval = [dict_retval[key] for key in self._output_keys]
        if len(final_retval) == 1:
            final_retval = final_retval[0]
        if updates:
            return final_retval, updates
        return final_retval

    def compute(self):
        """
        Run ``theano.scan`` and return its results keyed by output name.

        :return: ``(result_dict, updates)``. ``result_dict`` is a ``MapDict``
            mapping each ``outputs_info`` key to its scan output; in neural mode
            it is converted with ``convert_to_neural_var`` and ``output_dim`` is
            restored from the step function. ``updates`` is what
            ``theano.scan`` returned.
        """
        from deepy.core.tensor_conversion import convert_to_neural_var
        if self._neural_computation:
            self._input_dim_list = self._collect_input_dims()
        results, updates = theano.scan(self._func_wrapper,
                                       sequences=self._drop_none(self._sequence_values),
                                       outputs_info=self._output_values,
                                       non_sequences=self._drop_none(self._non_sequence_values),
                                       **self._kwargs)
        if not isinstance(results, list):
            results = [results]
        result_dict = MapDict(zip(self._output_keys, results))
        if self._neural_computation:
            result_dict = convert_to_neural_var(result_dict)
            for key in result_dict:
                if key in self._output_dim_map:
                    result_dict[key].output_dim = self._output_dim_map[key]
        return result_dict, updates