# -*- coding: utf-8 -*-
import inspect
import logging
from six import wraps

from picklable_itertools.extras import equizip
import theano
from theano import tensor, Variable

from ..base import Application, application, Brick
from ...initialization import NdarrayInitialization
from ...utils import pack, dict_union, dict_subset, is_shared_variable

logger = logging.getLogger(__name__)

unknown_scan_input = """

Your function uses a non-shared variable other than those given \
by scan explicitly. That can significantly slow down calls to \
`tensor.grad`. Did you forget to declare it in `contexts`?"""


class BaseRecurrent(Brick):
    """Base class for brick with recurrent application method."""
    has_bias = False

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        r"""Return initial states for an application call.

        Default implementation assumes that the recurrent application
        method is called `apply`. It fetches the state names
        from `apply.states` and returns a zero matrix for each of them.

        :class:`SimpleRecurrent`, :class:`LSTM` and :class:`GatedRecurrent`
        override this method with trainable initial states initialized
        with zeros.

        Parameters
        ----------
        batch_size : int
            The batch size.
        \*args
            The positional arguments of the application call.
        \*\*kwargs
            The keyword arguments of the application call.
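
        Examples
        --------
        A sketch of the default behaviour (the state name and dimension
        are illustrative, not part of the API):

        .. code-block:: python

            # With apply.states == ['states'] and get_dim('states') == 3,
            # this returns one zero matrix of shape (2, 3):
            initial = brick.initial_states(2)
            # A state of dimension 0 would yield a vector of shape
            # (batch_size,) instead.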

        """
        result = []
        for state in self.apply.states:
            dim = self.get_dim(state)
            if dim == 0:
                result.append(tensor.zeros((batch_size,)))
            else:
                result.append(tensor.zeros((batch_size, dim)))
        return result

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.apply.states


def recurrent(*args, **kwargs):
    """Wraps an apply method to allow its iterative application.

    This decorator allows you to implement only one step of a recurrent
    network and enjoy applying it to sequences for free. The idea behind
    it is that, in its most general form, the information flow of an RNN
    can be described as follows: depending on the context and driven by
    input sequences, the RNN updates its states and produces output
    sequences.

    Given a method describing one step of an RNN and a specification of
    which of its inputs are the elements of the input sequences, which
    are the states and which are the contexts, this decorator returns an
    application method which implements the whole RNN loop. The returned
    application method also has additional parameters, see documentation
    of the `recurrent_apply` inner function below.

    Parameters
    ----------
    sequences : list of str
        Specifies which of the arguments are elements of input sequences.
    states : list of str
        Specifies which of the arguments are the states.
    contexts : list of str
        Specifies which of the arguments are the contexts.
    outputs : list of str
        Names of the outputs. The outputs whose names match those in the
        `states` parameter are interpreted as next step states.

    Returns
    -------
    recurrent_apply : :class:`~blocks.bricks.base.Application`
        The new application method that applies the RNN to sequences.

    See Also
    --------
    :doc:`The tutorial on RNNs </rnn>`
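
    Examples
    --------
    A minimal sketch of a brick built with this decorator (the brick and
    its dimension handling are illustrative, not part of the library):

    .. code-block:: python

        class Accumulator(BaseRecurrent):
            # Adds its input to a running sum at every time step.
            @recurrent(sequences=['inputs'], states=['states'],
                       outputs=['states'], contexts=[])
            def apply(self, inputs, states):
                return states + inputs

            def get_dim(self, name):
                # Illustrative: real bricks dispatch on `name`.
                return 1

    Given a 3-D (time, batch, feature) tensor `x`,
    `Accumulator().apply(x)` iterates this one-step transition over the
    whole sequence with :func:`theano.scan`.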

    """
    def recurrent_wrapper(application_function):
        arg_spec = inspect.getargspec(application_function)
        arg_names = arg_spec.args[1:]

        @wraps(application_function)
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True``, the transition is iterated over the first
                dimension of the given sequences. ``True`` by default.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.
            scan_kwargs : dict
                Additional keyword arguments passed on to
                :func:`theano.scan`.

            """
            # Extract arguments related to iteration and immediately
            # relay the call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            scan_kwargs = kwargs.pop('scan_kwargs', {})
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequences and contexts were provided
            sequences_given = dict_subset(kwargs, application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs, application.contexts,
                                         must_have=False)

            # Determine the number of steps and the batch size from the
            # first sequence, if any; otherwise they must be given
            # explicitly as `n_steps` and `batch_size` keyword arguments.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the remaining kwargs
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    logger.warning("unknown input {}".format(value) +
                                   unknown_scan_input)

            # Ensure that all initial states are available.
            initial_states = brick.initial_states(batch_size, as_dict=True,
                                                  *args, **kwargs)
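            # An initial state may be overridden by a keyword argument
            # in one of three forms: an `NdarrayInitialization` (sampled
            # once and tiled over the batch), an `Application` that
            # computes the state, or a ready tensor variable used as-is.
            # States not given at all fall back to the defaults fetched
            # above from `brick.initial_states`.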
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs))
                else:
                    try:
                        kwargs[state_name] = initial_states[state_name]
                    except KeyError:
                        raise KeyError(
                            "no initial state for '{}' of the brick {}".format(
                                state_name, brick.name))
            states_given = dict_subset(kwargs, application.states)

            # Unbroadcast the initial states to work around Theano
            # issue 1772: scan fails when the broadcastable pattern of
            # an initial state differs from that of the corresponding
            # output.
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

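            # `theano.scan` calls `scan_function` positionally: first
            # the current slices of the sequences, then the previous
            # values of the recurrent outputs (in `outputs_info` order),
            # then the non-sequences (the contexts). The `arg_names`
            # list built below mirrors exactly that order.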
            def scan_function(*args):
                args = list(args)
                arg_names = (list(sequences_given) +
                             [output for output in application.outputs
                              if output in application.states] +
                             list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                kwargs.update(rest_kwargs)
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside the
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs

            outputs_info = [
                states_given[name] if name in application.states
                else None
                for name in application.outputs]
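
            # In `outputs_info` above, outputs that are also recurrent
            # states carry their initial values, while `None` entries
            # mark pure outputs that scan must not feed back into the
            # next step.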
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse,
                name='{}_{}_scan'.format(
                    brick.name, application.application_name),
                **scan_kwargs)
            result = pack(result)
            if return_initial_states:
                # Undo the Subtensor with which scan drops the initial
                # states: scan returns `full_states[1:]`, so the first
                # input of its owner is the full tensor, initial states
                # included.
                for i, info in enumerate(outputs_info):
                    if info is not None:
                        assert isinstance(result[i].owner.op,
                                          tensor.subtensor.Subtensor)
                        result[i] = result[i].owner.inputs[0]
            if updates:
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result

        return recurrent_apply

    # The decorator can be used with or without arguments
    assert (args and not kwargs) or (not args and kwargs)
    if args:
        application_function, = args
        return application(recurrent_wrapper(application_function))
    else:
        def wrap_application(application_function):
            return application(**kwargs)(
                recurrent_wrapper(application_function))
        return wrap_application
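
# A sketch of how a wrapped application is typically called (names are
# illustrative; `s0` stands for a tensor of initial states): given `x`
# of shape (n_steps, batch_size, dim),
#
#     states = brick.apply(x)                    # iterate over time
#     states = brick.apply(x, reverse=True)      # backwards in time
#     next_state = brick.apply(x[0], states=s0,  # a single step
#                              iterate=False)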