1
|
|
|
# -*- coding: utf-8 -*- |
2
|
|
|
import inspect |
3
|
|
|
import logging |
4
|
|
|
from six import wraps |
|
|
|
|
5
|
|
|
|
6
|
|
|
from picklable_itertools.extras import equizip |
7
|
|
|
import theano |
8
|
|
|
from theano import tensor, Variable |
9
|
|
|
|
10
|
|
|
from ..base import Application, application, Brick |
11
|
|
|
from ...initialization import NdarrayInitialization |
12
|
|
|
from ...utils import pack, dict_union, dict_subset, is_shared_variable |
13
|
|
|
|
14
|
|
|
# Module-level logger; handler/level configuration is left to the host
# application.
logger = logging.getLogger(__name__)

# Warning text appended to the log message emitted by `recurrent` when the
# wrapped step function receives a non-shared Theano variable that was not
# declared as a sequence, state or context of the application.
unknown_scan_input = """

Your function uses a non-shared variable other than those given \
by scan explicitly. That can significantly slow down `tensor.grad` \
call. Did you forget to declare it in `contexts`?"""
21
|
|
|
|
22
|
|
|
|
23
|
|
|
class BaseRecurrent(Brick):
    """Base class for bricks with a recurrent application method."""
    has_bias = False

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        r"""Return initial states for an application call.

        The default implementation assumes that the recurrent
        application method is called `apply`: it reads the state names
        from `apply.states` and produces a zero tensor for each of them.

        :class:`SimpleRecurrent`, :class:`LSTM` and :class:`GatedRecurrent`
        override this method with trainable initial states initialized
        with zeros.

        Parameters
        ----------
        batch_size : int
            The batch size.
        \*args
            The positional arguments of the application call.
        \*\*kwargs
            The keyword arguments of the application call.

        """
        def _zero_state(state_name):
            # A zero-dimensional state is a vector over the batch;
            # anything else is a (batch, dim) matrix of zeros.
            dim = self.get_dim(state_name)
            shape = (batch_size,) if dim == 0 else (batch_size, dim)
            return tensor.zeros(shape)

        return [_zero_state(name) for name in self.apply.states]

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.apply.states
|
|
|
|
61
|
|
|
|
62
|
|
|
|
63
|
|
|
def recurrent(*args, **kwargs):
    """Wraps an apply method to allow its iterative application.

    This decorator allows you to implement only one step of a recurrent
    network and enjoy applying it to sequences for free. The idea behind is
    that its most general form information flow of an RNN can be described
    as follows: depending on the context and driven by input sequences the
    RNN updates its states and produces output sequences.

    Given a method describing one step of an RNN and a specification
    which of its inputs are the elements of the input sequence,
    which are the states and which are the contexts, this decorator
    returns an application method which implements the whole RNN loop.
    The returned application method also has additional parameters,
    see documentation of the `recurrent_apply` inner function below.

    Parameters
    ----------
    sequences : list of strs
        Specifies which of the arguments are elements of input sequences.
    states : list of strs
        Specifies which of the arguments are the states.
    contexts : list of strs
        Specifies which of the arguments are the contexts.
    outputs : list of strs
        Names of the outputs. The outputs whose names match with those
        in the `state` parameter are interpreted as next step states.

    Returns
    -------
    recurrent_apply : :class:`~blocks.bricks.base.Application`
        The new application method that applies the RNN to sequences.

    See Also
    --------
    :doc:`The tutorial on RNNs </rnn>`

    """
    def recurrent_wrapper(application_function):
        # Argument names of the one-step function, minus `self`; used below
        # to map positional arguments of a call onto keyword arguments.
        # NOTE(review): `inspect.getargspec` is deprecated (removed in
        # Python 3.11) — kept as-is here for py2/six compatibility.
        arg_spec = inspect.getargspec(application_function)
        arg_names = arg_spec.args[1:]

        @wraps(application_function)
        def recurrent_apply(brick, application, application_call,
                            *args, **kwargs):
            """Iterates a transition function.

            Parameters
            ----------
            iterate : bool
                If ``True`` iteration is made. By default ``True``.
            reverse : bool
                If ``True``, the sequences are processed in backward
                direction. ``False`` by default.
            return_initial_states : bool
                If ``True``, initial states are included in the returned
                state tensors. ``False`` by default.
            scan_kwargs : dict, optional
                Extra keyword arguments forwarded verbatim to
                :func:`theano.scan`.

            """
            # Extract arguments related to iteration and immediately relay the
            # call to the wrapped function if `iterate=False`
            iterate = kwargs.pop('iterate', True)
            if not iterate:
                return application_function(brick, *args, **kwargs)
            reverse = kwargs.pop('reverse', False)
            scan_kwargs = kwargs.pop('scan_kwargs', {})
            return_initial_states = kwargs.pop('return_initial_states', False)

            # Push everything to kwargs
            for arg, arg_name in zip(args, arg_names):
                kwargs[arg_name] = arg

            # Make sure that all arguments for scan are tensor variables.
            # A `None` value means "not provided" and is simply dropped.
            scan_arguments = (application.sequences + application.states +
                              application.contexts)
            for arg in scan_arguments:
                if arg in kwargs:
                    if kwargs[arg] is None:
                        del kwargs[arg]
                    else:
                        kwargs[arg] = tensor.as_tensor_variable(kwargs[arg])

            # Check which sequence and contexts were provided
            sequences_given = dict_subset(kwargs, application.sequences,
                                          must_have=False)
            contexts_given = dict_subset(kwargs, application.contexts,
                                         must_have=False)

            # Determine number of steps and batch size.
            if len(sequences_given):
                # TODO Assumes 1 time dim!
                # Shapes are symbolic: axis 0 is time, axis 1 is batch.
                shape = list(sequences_given.values())[0].shape
                n_steps = shape[0]
                batch_size = shape[1]
            else:
                # TODO Raise error if n_steps and batch_size not found?
                # With no sequences the caller must supply both explicitly.
                n_steps = kwargs.pop('n_steps')
                batch_size = kwargs.pop('batch_size')

            # Handle the rest kwargs: anything that is not a declared scan
            # argument is passed through unchanged to every step. Warn when
            # such a value is a non-shared Theano variable, since scan would
            # treat it as an implicit input (see `unknown_scan_input`).
            rest_kwargs = {key: value for key, value in kwargs.items()
                           if key not in scan_arguments}
            for value in rest_kwargs.values():
                if (isinstance(value, Variable) and not
                        is_shared_variable(value)):
                    logger.warning("unknown input {}".format(value) +
                                   unknown_scan_input)

            # Ensure that all initial states are available.
            initial_states = brick.initial_states(batch_size, as_dict=True,
                                                  *args, **kwargs)
            for state_name in application.states:
                dim = brick.get_dim(state_name)
                if state_name in kwargs:
                    # An explicitly given state may be an initialization
                    # scheme, an application producing the state, or a
                    # ready tensor (left untouched in that case).
                    if isinstance(kwargs[state_name], NdarrayInitialization):
                        # Generate one row and broadcast it over the batch.
                        kwargs[state_name] = tensor.alloc(
                            kwargs[state_name].generate(brick.rng, (1, dim)),
                            batch_size, dim)
                    elif isinstance(kwargs[state_name], Application):
                        kwargs[state_name] = (
                            kwargs[state_name](state_name, batch_size,
                                               *args, **kwargs))
                else:
                    try:
                        kwargs[state_name] = initial_states[state_name]
                    except KeyError:
                        raise KeyError(
                            "no initial state for '{}' of the brick {}".format(
                                state_name, brick.name))
            states_given = dict_subset(kwargs, application.states)

            # Theano issue 1772: unbroadcast all state dimensions so that
            # scan does not fail on broadcastable-pattern mismatches between
            # initial states and the states returned by the step function.
            for name, state in states_given.items():
                states_given[name] = tensor.unbroadcast(state,
                                                        *range(state.ndim))

            def scan_function(*args):
                # Reassemble scan's positional arguments into keyword
                # arguments; scan passes them in the order: sequences,
                # recurrent outputs (states), then non-sequences (contexts).
                args = list(args)
                arg_names = (list(sequences_given) +
                             [output for output in application.outputs
                              if output in application.states] +
                             list(contexts_given))
                kwargs = dict(equizip(arg_names, args))
                kwargs.update(rest_kwargs)
                outputs = application(iterate=False, **kwargs)
                # We want to save the computation graph returned by the
                # `application_function` when it is called inside the
                # `theano.scan`.
                application_call.inner_inputs = args
                application_call.inner_outputs = pack(outputs)
                return outputs
            # One entry per declared output: the initial state for outputs
            # that are recurrent states, `None` for pure outputs.
            outputs_info = [
                states_given[name] if name in application.states
                else None
                for name in application.outputs]
            result, updates = theano.scan(
                scan_function, sequences=list(sequences_given.values()),
                outputs_info=outputs_info,
                non_sequences=list(contexts_given.values()),
                n_steps=n_steps,
                go_backwards=reverse,
                name='{}_{}_scan'.format(
                    brick.name, application.application_name),
                **scan_kwargs)
            result = pack(result)
            if return_initial_states:
                # Undo Subtensor: scan slices off the initial state from the
                # returned tensors; recover the unsliced tensor (which still
                # contains the initial states) from the Subtensor op's input.
                for i, info in enumerate(outputs_info):
                    if info is not None:
                        assert isinstance(result[i].owner.op,
                                          tensor.subtensor.Subtensor)
                        result[i] = result[i].owner.inputs[0]
            if updates:
                # Propagate scan's shared-variable updates to the
                # application call so they reach the final function.
                application_call.updates = dict_union(application_call.updates,
                                                      updates)

            return result

        return recurrent_apply

    # Decorator can be used with or without arguments:
    # `@recurrent` passes the function positionally; `@recurrent(...)`
    # passes only keyword arguments and returns a decorator.
    assert (args and not kwargs) or (not args and kwargs)
    if args:
        application_function, = args
        return application(recurrent_wrapper(application_function))
    else:
        def wrap_application(application_function):
            return application(**kwargs)(
                recurrent_wrapper(application_function))
        return wrap_application
253
|
|
|
|