# -*- coding: utf-8 -*-
import copy

from picklable_itertools.extras import equizip
from theano import tensor

from ..base import application, lazy
from ..parallel import Fork
from ..simple import Initializable, Linear
from .base import BaseRecurrent, recurrent


class Bidirectional(Initializable):
    """Bidirectional network.

    A bidirectional network is a combination of forward and backward
    recurrent networks which process the input sequence in opposite
    orders.

    Parameters
    ----------
    prototype : instance of :class:`BaseRecurrent`
        A prototype brick from which the forward and backward bricks
        are cloned.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

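    Examples
    --------
    A minimal sketch, assuming a :class:`SimpleRecurrent` prototype
    (the import path and dimension are illustrative):

    >>> from blocks.bricks import Tanh
    >>> from blocks.bricks.recurrent import SimpleRecurrent
    >>> bidir = Bidirectional(SimpleRecurrent(dim=4, activation=Tanh()))
    >>> bidir.get_dim('states')  # forward and backward are concatenated
    8
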
    """
    has_bias = False

    @lazy()
    def __init__(self, prototype, **kwargs):
        self.prototype = prototype

        children = [copy.deepcopy(prototype) for _ in range(2)]
        children[0].name = 'forward'
        children[1].name = 'backward'
        kwargs.setdefault('children', []).extend(children)
        super(Bidirectional, self).__init__(**kwargs)

    @application
    def apply(self, *args, **kwargs):
        """Applies forward and backward networks and concatenates outputs."""
        forward = self.children[0].apply(as_list=True, *args, **kwargs)
        backward = [x[::-1] for x in
                    self.children[1].apply(reverse=True, as_list=True,
                                           *args, **kwargs)]
        return [tensor.concatenate([f, b], axis=2)
                for f, b in equizip(forward, backward)]

    @apply.delegate
    def apply_delegate(self):
        return self.children[0].apply

    def get_dim(self, name):
        if name in self.apply.outputs:
            return self.prototype.get_dim(name) * 2
        return self.prototype.get_dim(name)

RECURRENTSTACK_SEPARATOR = '#'


class RecurrentStack(BaseRecurrent, Initializable):
    u"""Stack of recurrent networks.

    Builds a stack of recurrent layers from a supplied list of
    :class:`~blocks.bricks.recurrent.BaseRecurrent` objects.
    Each object must have `sequences`, `contexts`, `states` and
    `outputs` parameters to its `apply` method, such as the ones
    required by the recurrent decorator from
    :mod:`blocks.bricks.recurrent`.

    In Blocks in general each brick can have an apply method and this
    method has attributes that list the names of the arguments that can
    be passed to the method and the names of the outputs returned by
    the method.
    The attributes of the apply method of this class are made by
    concatenating the attributes of the apply methods of each of the
    transitions from which the stack is made.
    In order to avoid conflicts, the names of the arguments appearing
    in the `states` and `outputs` attributes of the apply method of
    each layer are renamed. The names of the bottom layer are used
    as-is and a suffix of the form '#<n>' is added to the names from
    other layers, where '<n>' is the number of the layer, starting
    from 1 for the first layer above the bottom.

    The `contexts` of all layers are merged into a single list of
    unique names, and no suffix is added. Different layers with the
    same context name will receive the same value.

    The names that appear in `sequences` are treated in the same way
    as the names of `states` and `outputs` if `skip_connections` is
    "True". The only exception is the "mask" element that may appear
    in the `sequences` attribute of all layers: no suffix is added to
    it and all layers will receive the same mask value.
    If you set `skip_connections` to False then only the arguments of
    the `sequences` of the bottom layer will appear in the `sequences`
    attribute of the apply method of this class.
    When using this class with `skip_connections` set to "True", you
    can supply all inputs to all layers using a single fork which is
    created with `output_names` set to the `apply.sequences` attribute
    of this class. For example,
    :class:`~blocks.bricks.sequence_generators.SequenceGenerator` will
    create such a fork.

    Whether or not `skip_connections` is set, each layer above the
    bottom also receives an input (values to its `sequences` arguments)
    from a fork of the state of the layer below it. This is not to be
    confused with the external fork discussed in the previous
    paragraph.
    It is assumed that all `states` attributes have a "states" argument
    name (this can be configured with the `states_name` parameter).
    The output argument with this name is forked and then added to all
    the elements appearing in the `sequences` of the next layer (except
    for "mask").
    If `skip_connections` is False then this fork has a bias by
    default. This allows direct usage of this class with input supplied
    only to the first layer. But if you do supply inputs to all layers
    (by setting `skip_connections` to "True") then by default there is
    no bias and the external fork you use to supply the inputs should
    have its own separate bias.

    Parameters
    ----------
    transitions : list
        List of recurrent units to use in each layer. Each derived from
        :class:`~blocks.bricks.recurrent.BaseRecurrent`.
        Note: a suffix with the layer number is added to the
        transitions' names.
    fork_prototype : :class:`~blocks.bricks.FeedForward`, optional
        A prototype for the transformation applied to the `states_name`
        element of the states of each layer. The transformation is used
        when the `states_name` argument from the `outputs` of one layer
        is used as input to the `sequences` of the next layer. By
        default a :class:`~blocks.bricks.Linear` transformation is
        used, with a bias if `skip_connections` is "False". If you
        supply your own prototype you have to enable/disable the bias
        depending on the value of `skip_connections`.
    states_name : string
        In a stack of RNNs the state of each layer is used as input to
        the next. The `states_name` identifies the argument of the
        `states` and `outputs` attributes of each layer that should be
        used for this task. By default the argument is called "states".
        To be more precise, this is the name of the argument in the
        `outputs` attribute of the apply method of each transition
        (layer). It is used, via a fork, as the `sequences` (input) of
        the next layer. The same element should also appear in the
        `states` attribute of the apply method.
    skip_connections : bool
        By default False. When True, the `sequences` of all layers are
        added to the `sequences` of the apply of this class. When
        False, only the `sequences` of the bottom layer appear in the
        `sequences` of the apply of this class. In this case the
        default fork used internally between layers has a bias (see
        `fork_prototype`).
        External code can inspect the `sequences` attribute of the
        apply method of this class to decide which arguments it needs
        (and in what order). With `skip_connections` you can control
        what is exposed to the external code. If it is False then the
        external code is expected to supply inputs only to the bottom
        layer, and if it is True then the external code is expected to
        supply inputs to all layers. There is just one small catch: the
        external inputs to the layers above the bottom layer are added
        to a fork of the state of the layer below it. As a result the
        outputs of two forks are added together, which is problematic
        if both have a bias. It is assumed that the external fork has a
        bias, and therefore by default the internal fork will not have
        a bias if `skip_connections` is True.

    Notes
    -----
    See :class:`.BaseRecurrent` for more initialization parameters.

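    Examples
    --------
    A sketch of the naming scheme, assuming two :class:`SimpleRecurrent`
    transitions (the import path and dimensions are illustrative):

    >>> from blocks.bricks import Tanh
    >>> from blocks.bricks.recurrent import SimpleRecurrent
    >>> stack = RecurrentStack(
    ...     [SimpleRecurrent(dim=4, activation=Tanh()),
    ...      SimpleRecurrent(dim=8, activation=Tanh())])
    >>> stack.apply.states
    ['states', 'states#1']
    >>> stack.get_dim('states#1')
    8
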
    """
    @staticmethod
    def suffix(name, level):
        if name == "mask":
            return "mask"
        if level == 0:
            return name
        return name + RECURRENTSTACK_SEPARATOR + str(level)

    @staticmethod
    def suffixes(names, level):
        return [RecurrentStack.suffix(name, level)
                for name in names if name != "mask"]

    @staticmethod
    def split_suffix(name):
        # Route a suffixed name to the correct layer.
        name_level = name.rsplit(RECURRENTSTACK_SEPARATOR, 1)
        if len(name_level) == 2 and name_level[-1].isdigit():
            name = name_level[0]
            level = int(name_level[-1])
        else:
            # It must be from the bottom layer.
            level = 0
        return name, level

    def __init__(self, transitions, fork_prototype=None, states_name="states",
                 skip_connections=False, **kwargs):
        super(RecurrentStack, self).__init__(**kwargs)

        self.states_name = states_name
        self.skip_connections = skip_connections

        for level, transition in enumerate(transitions):
            transition.name += RECURRENTSTACK_SEPARATOR + str(level)
        self.transitions = transitions

        if fork_prototype is None:
            # If we are not supplied any inputs for the layers above
            # the bottom then use a bias.
            fork_prototype = Linear(use_bias=not skip_connections)
        depth = len(transitions)
        self.forks = [Fork(self.normal_inputs(level),
                           name='fork_' + str(level),
                           prototype=fork_prototype)
                      for level in range(1, depth)]

        self.children = self.transitions + self.forks

        # Programmatically set the apply parameters.
        # The parameters of the base level are exposed as-is, except
        # for the mask, which we will put at the very end (see below).
        for property_ in ["sequences", "states", "outputs"]:
            setattr(self.apply,
                    property_,
                    self.suffixes(getattr(transitions[0].apply, property_),
                                  0))

        # Add the parameters of the other layers.
        if skip_connections:
            exposed_arguments = ["sequences", "states", "outputs"]
        else:
            exposed_arguments = ["states", "outputs"]
        for level in range(1, depth):
            for property_ in exposed_arguments:
                setattr(self.apply,
                        property_,
                        getattr(self.apply, property_) +
                        self.suffixes(getattr(transitions[level].apply,
                                              property_),
                                      level))

        # Place the mask at the end because it has a default value
        # (None) and therefore should come after arguments that may be
        # passed as unnamed arguments.
        if "mask" in transitions[0].apply.sequences:
            self.apply.sequences.append("mask")

        # Add the contexts.
        self.apply.contexts = list(set(
            sum([transition.apply.contexts for transition in transitions], [])
        ))

        # Collect all the arguments we expect to see in a call to a
        # transition's apply method; anything else is recursion control.
        self.transition_args = set(self.apply.sequences +
                                   self.apply.states +
                                   self.apply.contexts)

        for property_ in ["sequences", "states", "contexts", "outputs"]:
            setattr(self.low_memory_apply, property_,
                    getattr(self.apply, property_))

        self.initial_states.outputs = self.apply.states

    def normal_inputs(self, level):
        return [name for name in self.transitions[level].apply.sequences
                if name != 'mask']

    def _push_allocation_config(self):
        # Configure the forks that connect the "states" element in the
        # `states` of one layer to the elements in the `sequences` of
        # the next layer, excluding "mask".
        # This involves `get_dim` requests to the transitions. To make
        # sure they answer correctly we should finish their
        # configuration first.
        for transition in self.transitions:
            transition.push_allocation_config()

        for level, fork in enumerate(self.forks):
            fork.input_dim = self.transitions[level].get_dim(self.states_name)
            fork.output_dims = self.transitions[level + 1].get_dims(
                fork.output_names)

    def do_apply(self, *args, **kwargs):
        """Apply the stack of transitions.

        This is the undecorated implementation of the apply method.
        A method decorated with `@application` should call this method
        with `iterate=True` (the default) to indicate that the
        iteration over all steps should be done internally by this
        method. A method decorated with `@recurrent` should pass
        `iterate=False` to indicate that the iteration over all steps
        is done externally.

        """
        nargs = len(args)
        args_names = self.apply.sequences + self.apply.contexts
        assert nargs <= len(args_names)
        kwargs.update(zip(args_names[:nargs], args))

        if kwargs.get("reverse", False):
            raise NotImplementedError

        results = []
        last_states = None
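        # Descriptive note: layers are processed bottom-up; the
        # `states_name` output of each layer is forked (below) into the
        # `sequences` inputs of the layer above it.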
        for level, transition in enumerate(self.transitions):
            normal_inputs = self.normal_inputs(level)
            layer_kwargs = dict()

            if level == 0 or self.skip_connections:
                for name in normal_inputs:
                    layer_kwargs[name] = kwargs.get(self.suffix(name, level))
            if "mask" in transition.apply.sequences:
                layer_kwargs["mask"] = kwargs.get("mask")

            for name in transition.apply.states:
                layer_kwargs[name] = kwargs.get(self.suffix(name, level))

            for name in transition.apply.contexts:
                layer_kwargs[name] = kwargs.get(name)  # contexts have no suffix

            if level > 0:
                # Add the forked states of the layer below.
                inputs = self.forks[level - 1].apply(last_states, as_list=True)
                for name, input_ in zip(normal_inputs, inputs):
                    if layer_kwargs.get(name):
                        layer_kwargs[name] += input_
                    else:
                        layer_kwargs[name] = input_

            # Handle all other arguments. For example, if the method is
            # called directly (`low_memory=False`), then the arguments
            # that `recurrent` expects to see, such as 'iterate',
            # 'reverse' and 'return_initial_states', may appear.
            for k in set(kwargs.keys()) - self.transition_args:
                layer_kwargs[k] = kwargs[k]

            result = transition.apply(as_list=True, **layer_kwargs)
            results.extend(result)

            state_index = transition.apply.outputs.index(self.states_name)
            last_states = result[state_index]
            if kwargs.get('return_initial_states', False):
                # Note that the following line resets the tag.
                last_states = last_states[1:]

        return tuple(results)

    @recurrent
    def low_memory_apply(self, *args, **kwargs):
        # We let the recurrent decorator handle the iteration for us,
        # so do_apply needs to do only a single step.
        kwargs['iterate'] = False
        return self.do_apply(*args, **kwargs)

    @application
    def apply(self, *args, **kwargs):
        r"""Apply the stack of transitions.

        Parameters
        ----------
        low_memory : bool
            Use the slow, but also memory efficient, implementation of
            this code.
        \*args : :class:`~tensor.TensorVariable`, optional
            Positional arguments in the order in which they appear in
            `self.apply.sequences` followed by `self.apply.contexts`.
        \*\*kwargs : :class:`~tensor.TensorVariable`
            Named arguments defined in `self.apply.sequences`,
            `self.apply.states` or `self.apply.contexts`.

        Returns
        -------
        outputs : (list of) :class:`~tensor.TensorVariable`
            The outputs of all transitions as defined in
            `self.apply.outputs`.

        See Also
        --------
        See the docstring of this class for the arguments appearing in
        the lists `self.apply.sequences`, `self.apply.states` and
        `self.apply.contexts`.
        See :func:`~blocks.bricks.recurrent.recurrent` for all other
        parameters, such as `iterate` and `return_initial_states`; note
        that `reverse` is currently not implemented.

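        Examples
        --------
        A sketch of a symbolic call (`SimpleRecurrent`, the import
        paths and all dimensions are illustrative assumptions); pass
        ``low_memory=True`` for the memory-efficient path:

        >>> from blocks.bricks import Tanh
        >>> from blocks.bricks.recurrent import SimpleRecurrent
        >>> stack = RecurrentStack(
        ...     [SimpleRecurrent(dim=4, activation=Tanh()),
        ...      SimpleRecurrent(dim=4, activation=Tanh())])
        >>> x = tensor.tensor3('x')
        >>> mask = tensor.matrix('mask')
        >>> outputs = stack.apply(x, mask=mask)
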
        """
        if kwargs.pop('low_memory', False):
            return self.low_memory_apply(*args, **kwargs)
        # We let each transition in self.transitions do its own
        # iteration separately, one layer at a time.
        return self.do_apply(*args, **kwargs)

    def get_dim(self, name):
        # Check if we have a contexts element.
        for transition in self.transitions:
            if name in transition.apply.contexts:
                # Hopefully there is no conflict between layers about
                # the dimension.
                return transition.get_dim(name)

        name, level = self.split_suffix(name)
        transition = self.transitions[level]
        return transition.get_dim(name)

    @application
    def initial_states(self, batch_size, *args, **kwargs):
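        # Descriptive note: the initial states of all layers are
        # concatenated in the order of `self.initial_states.outputs`
        # (which mirrors `self.apply.states`).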
        results = []
        for transition in self.transitions:
            results += transition.initial_states(batch_size, *args,
                                                 as_list=True, **kwargs)
        return results