"""Attention mechanisms.

This module defines the interface of attention mechanisms and a few
concrete implementations. For a gentle introduction and usage examples see
the tutorial TODO.

An attention mechanism decides which part of the input to pay attention to.
It is typically used as a component of a recurrent network, though one can
imagine it used in other settings as well. When the input is big and has a
certain structure, for instance when it is a sequence or an image, an
attention mechanism can be applied to extract only the information which is
relevant for the network in its current state.

For the purpose of documentation clarity, we fix the following terminology
in this file:

* *network* is the network, typically a recurrent one, which
  uses the attention mechanism.

* The network has *states*. Using this word in plural might seem weird, but
  some recurrent networks like :class:`~blocks.bricks.recurrent.LSTM` do
  have several states.

* The big structured input, to which the attention mechanism is applied,
  is called the *attended*. When it has variable structure, e.g. a sequence
  of variable length, there might be a *mask* associated with it.

* The information extracted by the attention from the attended is called
  a *glimpse*, or more precisely *glimpses*, because there might be a few
  pieces of this information.

Using this terminology, the attention mechanism computes glimpses
given the states of the network and the attended.

An example: in the machine translation network from [BCB]_ the attended is
a sequence of so-called annotations, that is, the states of a bidirectional
network that was driven by word embeddings of the source sentence. The
attention mechanism assigns weights to the annotations. The weighted sum of
the annotations is further used by the translation network to predict the
next word of the generated translation. The weights and the weighted sum
are the glimpses. A generalization of the attention mechanism from this
paper is represented here as :class:`SequenceContentAttention`.

"""
from abc import ABCMeta, abstractmethod

from theano import tensor
from six import add_metaclass

from blocks.bricks import (Brick, Initializable, Sequence,
                           Feedforward, Linear, Tanh)
from blocks.bricks.base import lazy, application
from blocks.bricks.parallel import Parallel, Distribute
from blocks.bricks.recurrent import recurrent, BaseRecurrent
from blocks.utils import dict_union, dict_subset, pack


class AbstractAttention(Brick):
    """The common interface for attention bricks.

    First, see the module-level docstring for terminology.

    A generic attention mechanism functions as follows. Its inputs are the
    states of the network and the attended. Given these two it produces
    so-called *glimpses*, that is, it extracts from the attended the
    information which is necessary for the network in its current states.

    For computational reasons we separate the process described above into
    two stages:

    1. The preprocessing stage, :meth:`preprocess`, includes computations
       that do not involve the states. These can often be performed in
       advance. The outcome of this stage is called
       *preprocessed_attended*.

    2. The main stage, :meth:`take_glimpses`, includes all the rest.

    When an attention mechanism is applied sequentially, some glimpses from
    the previous step might be necessary to compute the new ones. A
    typical example for that is when the focus position from the previous
    step is required. In such cases :meth:`take_glimpses` should specify
    such a need in its interface (its docstring explains how to do that).
    In addition :meth:`initial_glimpses` should specify some sensible
    initialization for the glimpses to be carried over.

    .. todo::

        Only a single attended is currently allowed.

        :meth:`preprocess` and :meth:`initial_glimpses` might end up
        needing masks, which are currently not provided for them.

    Parameters
    ----------
    state_names : list
        The names of the network states.
    state_dims : list
        The state dimensions corresponding to `state_names`.
    attended_dim : int
        The dimension of the attended.

    Attributes
    ----------
    state_names : list
    state_dims : list
    attended_dim : int
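
    Examples
    --------
    The two stages are typically used as follows (a schematic sketch that
    assumes a concrete subclass such as :class:`SequenceContentAttention`;
    the variable names are hypothetical)::

        preprocessed = attention.preprocess(attended)
        glimpses = attention.take_glimpses(
            attended, preprocessed_attended=preprocessed,
            attended_mask=attended_mask, **states)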

    """
    @lazy(allocation=['state_names', 'state_dims', 'attended_dim'])
    def __init__(self, state_names, state_dims, attended_dim, **kwargs):
        self.state_names = state_names
        self.state_dims = state_dims
        self.attended_dim = attended_dim
        super(AbstractAttention, self).__init__(**kwargs)

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Perform the preprocessing of the attended.

        Stage 1 of the attention mechanism, see :class:`AbstractAttention`
        docstring for an explanation of stages. The default implementation
        simply returns attended.

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        preprocessed_attended : :class:`~theano.Variable`
            The preprocessed attended.

        """
        return attended

    @abstractmethod
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **kwargs):
        r"""Extract glimpses from the attended given the current states.

        Stage 2 of the attention mechanism, see :class:`AbstractAttention`
        for an explanation of stages. If `preprocessed_attended` is not
        given, it should trigger stage 1.

        This application method *must* declare its inputs and outputs.
        The glimpses to be carried over are identified by their presence
        in both the inputs and the outputs lists. The attended *must* be
        the first input, the preprocessed attended *must* be the second one.
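
        For example, a subclass that carries its attention weights over to
        the next step could declare them roughly like this (a sketch
        modelled on :class:`SequenceContentAttention` below; the glimpse
        names are illustrative)::

            @application(outputs=['weighted_averages', 'weights'])
            def take_glimpses(self, attended, preprocessed_attended=None,
                              attended_mask=None, **states):
                ...

            @take_glimpses.property('inputs')
            def take_glimpses_inputs(self):
                return (['attended', 'preprocessed_attended',
                         'attended_mask', 'weights'] + self.state_names)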

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.
        preprocessed_attended : :class:`~theano.Variable`, optional
            The preprocessed attended computed by :meth:`preprocess`. When
            not given, :meth:`preprocess` should be called.
        attended_mask : :class:`~theano.Variable`, optional
            The mask for the attended. This is required in the case of
            padded structured output, e.g. when a number of sequences are
            forced to be the same length. The mask identifies the positions
            of the `attended` that actually contain information.
        \*\*kwargs : dict
            Includes the states and the glimpses to be carried over from
            the previous step in the case when the attention mechanism is
            applied sequentially.

        """
        pass

    @abstractmethod
    def initial_glimpses(self, batch_size, attended):
        """Return sensible initial values for carried over glimpses.

        Parameters
        ----------
        batch_size : int or :class:`~theano.Variable`
            The batch size.
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        initial_glimpses : list of :class:`~theano.Variable`
            The initial values for the requested glimpses. These might
            simply consist of zeros or be somehow extracted from
            the attended.

        """
        pass

    def get_dim(self, name):
        if name in ['attended', 'preprocessed_attended']:
            return self.attended_dim
        if name in ['attended_mask']:
            return 0
        return super(AbstractAttention, self).get_dim(name)


class GenericSequenceAttention(AbstractAttention):
    """Logic common for sequence attention mechanisms."""
    @application
    def compute_weights(self, energies, attended_mask):
        r"""Compute weights from energies in softmax-like fashion.

        .. todo::

            Use :class:`~blocks.bricks.Softmax`.

        Parameters
        ----------
        energies : :class:`~theano.Variable`
            The energies. Must be of the same shape as the mask.
        attended_mask : :class:`~theano.Variable`
            The mask for the attended. The index in the sequence must be
            the first dimension.

        Returns
        -------
        weights : :class:`~theano.Variable`
            Non-negative weights summing to 1, of the same shape
            as `energies`.
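
        Notes
        -----
        For a single sequence of energies :math:`e_i` with mask values
        :math:`m_i`, the returned weights are

        .. math::

            w_i = \frac{m_i \exp(e_i - \max_j e_j)}
                  {\sum_j m_j \exp(e_j - \max_j e_j) + z},

        where :math:`z` equals 1 if the mask is all zeros and 0 otherwise,
        so that the denominator never vanishes. The shift by the maximum
        energy is done purely for numerical stability.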

        """
        # Stabilize energies first and then exponentiate
        energies = energies - energies.max(axis=0)
        unnormalized_weights = tensor.exp(energies)
        if attended_mask:
            unnormalized_weights *= attended_mask

        # If mask consists of all zeros use 1 as the normalization coefficient
        normalization = (unnormalized_weights.sum(axis=0) +
                         tensor.all(1 - attended_mask, axis=0))
        return unnormalized_weights / normalization

    @application
    def compute_weighted_averages(self, weights, attended):
        """Compute weighted averages of the attended sequence vectors.

        Parameters
        ----------
        weights : :class:`~theano.Variable`
            The weights. The shape must be equal to the attended shape
            without the last dimension.
        attended : :class:`~theano.Variable`
            The attended. The index in the sequence must be the first
            dimension.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            The weighted averages of the attended elements. The shape
            is equal to the attended shape with the first dimension
            dropped.

        """
        return (tensor.shape_padright(weights) * attended).sum(axis=0)


class SequenceContentAttention(GenericSequenceAttention, Initializable):
    """Attention mechanism that looks for relevant content in a sequence.

    This is the attention mechanism used in [BCB]_. The idea in a nutshell:

    1. The states and the sequence are transformed independently,

    2. The transformed states are summed with every transformed sequence
       element to obtain *match vectors*,

    3. A match vector is transformed into a single number interpreted as
       *energy*,

    4. Energies are normalized in softmax-like fashion. The resulting
       weights, which sum to one, are called *attention weights*,

    5. A weighted average of the sequence elements with the attention
       weights is computed.

    In terms of the :class:`AbstractAttention` documentation, the sequence
    is the attended. The weighted averages from step 5 and the attention
    weights from step 4 form the set of glimpses produced by this attention
    mechanism.
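
    With the default transformers (a bias-free linear transformation for
    every state, an affine transformation of the sequence elements and a
    :class:`ShallowEnergyComputer`), steps 1-3 compute energies of the form

    .. math:: e_i = v^T tanh(W s + U h_i + b),

    where :math:`s` are the network states, :math:`h_i` is the :math:`i`-th
    sequence element and :math:`v, W, U, b` are learned parameters (with a
    separate :math:`W` for every state).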

    Parameters
    ----------
    state_names : list of str
        The names of the network states.
    attended_dim : int
        The dimension of the sequence elements.
    match_dim : int
        The dimension of the match vector.
    state_transformer : :class:`~.bricks.Brick`
        A prototype for state transformations. If ``None``,
        a linear transformation is used.
    attended_transformer : :class:`.Feedforward`
        The transformation to be applied to the sequence. If ``None``, an
        affine transformation is used.
    energy_computer : :class:`.Feedforward`
        Computes the energy from the match vector. If ``None``, a
        :class:`ShallowEnergyComputer` (a weighted sum preceded by
        :math:`tanh`) is used.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    .. [BCB] Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio. Neural
        Machine Translation by Jointly Learning to Align and Translate.
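
    Examples
    --------
    A minimal construction sketch (the state name and the dimensions below
    are arbitrary; in practice the brick is usually wrapped in an
    :class:`AttentionRecurrent` or a sequence generator, which push the
    state dimensions automatically)::

        attention = SequenceContentAttention(
            state_names=['states'], state_dims=[100],
            attended_dim=200, match_dim=50, name="attention")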

    """
    @lazy(allocation=['match_dim'])
    def __init__(self, match_dim, state_transformer=None,
                 attended_transformer=None, energy_computer=None, **kwargs):
        if not state_transformer:
            state_transformer = Linear(use_bias=False)
        self.match_dim = match_dim
        self.state_transformer = state_transformer

        self.state_transformers = Parallel(input_names=kwargs['state_names'],
                                           prototype=state_transformer,
                                           name="state_trans")
        if not attended_transformer:
            attended_transformer = Linear(name="preprocess")
        if not energy_computer:
            energy_computer = ShallowEnergyComputer(name="energy_comp")
        self.attended_transformer = attended_transformer
        self.energy_computer = energy_computer

        children = [self.state_transformers, attended_transformer,
                    energy_computer]
        kwargs.setdefault('children', []).extend(children)
        super(SequenceContentAttention, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.state_transformers.input_dims = self.state_dims
        self.state_transformers.output_dims = [self.match_dim
                                               for name in self.state_names]
        self.attended_transformer.input_dim = self.attended_dim
        self.attended_transformer.output_dim = self.match_dim
        self.energy_computer.input_dim = self.match_dim
        self.energy_computer.output_dim = 1

    @application
    def compute_energies(self, attended, preprocessed_attended, states):
        if not preprocessed_attended:
            preprocessed_attended = self.preprocess(attended)
        transformed_states = self.state_transformers.apply(as_dict=True,
                                                           **states)
        # Broadcasting of transformed states should be done automatically
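        # (the states enter as (batch, match_dim) and the preprocessed
        #  attended as (length, batch, match_dim), so the sum broadcasts
        #  over the time axis)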
        match_vectors = sum(transformed_states.values(),
                            preprocessed_attended)
        energies = self.energy_computer.apply(match_vectors).reshape(
            match_vectors.shape[:-1], ndim=match_vectors.ndim - 1)
        return energies

    @application(outputs=['weighted_averages', 'weights'])
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **states):
        r"""Compute attention weights and produce glimpses.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The sequence, time is the first dimension.
        preprocessed_attended : :class:`~tensor.TensorVariable`
            The preprocessed sequence. If ``None``, it is computed by
            calling :meth:`preprocess`.
        attended_mask : :class:`~tensor.TensorVariable`
            A 0/1 mask specifying available data. 0 means that the
            corresponding sequence element is fake.
        \*\*states
            The states of the network.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            Linear combinations of sequence elements with the attention
            weights.
        weights : :class:`~theano.Variable`
            The attention weights. The first dimension is batch, the second
            is time.

        """
        energies = self.compute_energies(attended, preprocessed_attended,
                                         states)
        weights = self.compute_weights(energies, attended_mask)
        weighted_averages = self.compute_weighted_averages(weights, attended)
        return weighted_averages, weights.T

    @take_glimpses.property('inputs')
    def take_glimpses_inputs(self):
        return (['attended', 'preprocessed_attended', 'attended_mask'] +
                self.state_names)

    @application(outputs=['weighted_averages', 'weights'])
    def initial_glimpses(self, batch_size, attended):
        return [tensor.zeros((batch_size, self.attended_dim)),
                tensor.zeros((batch_size, attended.shape[0]))]

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Preprocess the sequence for computing attention weights.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The attended sequence, time is the first dimension.

        """
        return self.attended_transformer.apply(attended)

    def get_dim(self, name):
        if name in ['weighted_averages']:
            return self.attended_dim
        if name in ['weights']:
            return 0
        return super(SequenceContentAttention, self).get_dim(name)


class ShallowEnergyComputer(Sequence, Initializable, Feedforward):
    """A simple energy computer: first tanh, then weighted sum.

    Parameters
    ----------
    use_bias : bool, optional
        Whether a bias should be added to the energies. Does not change
        anything if softmax normalization is used to produce the attention
        weights, but might be useful when e.g. spherical softmax is used.
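
    Notes
    -----
    With the output dimension set to 1 (as :class:`SequenceContentAttention`
    does), this brick computes :math:`e = w^T tanh(x)`, plus a bias when
    `use_bias` is ``True``.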

    """
    @lazy()
    def __init__(self, use_bias=False, **kwargs):
        super(ShallowEnergyComputer, self).__init__(
            [Tanh().apply, Linear(use_bias=use_bias).apply], **kwargs)

    @property
    def input_dim(self):
        return self.children[1].input_dim

    @input_dim.setter
    def input_dim(self, value):
        self.children[1].input_dim = value

    @property
    def output_dim(self):
        return self.children[1].output_dim

    @output_dim.setter
    def output_dim(self, value):
        self.children[1].output_dim = value


@add_metaclass(ABCMeta)
class AbstractAttentionRecurrent(BaseRecurrent):
    """The interface for attention-equipped recurrent transitions.

    When a recurrent network is equipped with an attention mechanism its
    transition typically consists of two steps: (1) the glimpses are taken
    by the attention mechanism and (2) the next states are computed using
    the current states and the glimpses. Certain use cases (such as a
    sequence generator) require that, apart from a do-it-all recurrent
    application method, separate interfaces for the first and the second
    step of the transition are provided.

    """
    @abstractmethod
    def apply(self, **kwargs):
        """Compute next states taking glimpses on the way."""
        pass

    @abstractmethod
    def take_glimpses(self, **kwargs):
        """Compute glimpses given the current states."""
        pass

    @abstractmethod
    def compute_states(self, **kwargs):
        """Compute next states given current states and glimpses."""
        pass


class AttentionRecurrent(AbstractAttentionRecurrent, Initializable):
    """Combines an attention mechanism and a recurrent transition.

    This brick equips a recurrent transition with an attention mechanism.
    In order to do this two more contexts are added: one to be attended and
    a mask for it. It is also possible to use the contexts of the given
    recurrent transition for these purposes and not add any new ones,
    see the `add_contexts` parameter.

    At the beginning of each step the attention mechanism produces glimpses;
    these glimpses together with the current states are used to compute the
    next state and finish the transition. In some cases glimpses from the
    previous steps are also necessary for the attention mechanism, e.g.
    in order to focus on an area close to the one from the previous step.
    This is also supported: such glimpses become states of the new
    transition.

    To let the user control the way glimpses are used, this brick also
    takes a "distribute" brick as a parameter that distributes the
    information from the glimpses across the sequential inputs of the
    wrapped recurrent transition.

    Parameters
    ----------
    transition : :class:`.BaseRecurrent`
        The recurrent transition.
    attention : :class:`~.bricks.Brick`
        The attention mechanism.
    distribute : :class:`~.bricks.Brick`, optional
        Distributes the information from glimpses across the input
        sequences of the transition. By default a :class:`.Distribute` is
        used, and those inputs containing the "mask" substring in their
        name are not affected.
    add_contexts : bool, optional
        If ``True``, new contexts for the attended and the attended mask
        are added to this transition, otherwise existing contexts of the
        wrapped transition are used. ``True`` by default.
    attended_name : str
        The name of the attended context. If ``None``, "attended"
        or the first context of the recurrent transition is used
        depending on the value of the `add_contexts` flag.
    attended_mask_name : str
        The name of the mask for the attended context. If ``None``,
        "attended_mask" or the second context of the recurrent transition
        is used depending on the value of the `add_contexts` flag.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    Wrapping your recurrent brick with this class makes all the
    states mandatory. If you feel this is a limitation for you, try
    to make it better! This restriction does not apply to sequences
    and contexts: those keep being as optional as they were for
    your brick.

    Those coming to Blocks from Groundhog might recognize that this is
    a `RecurrentLayerWithSearch`, but on steroids :)
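
    Examples
    --------
    A rough construction sketch (the bricks and dimensions below are chosen
    arbitrarily; :class:`~blocks.bricks.recurrent.SimpleRecurrent` is just
    one possible transition)::

        transition = SimpleRecurrent(dim=100, activation=Tanh(),
                                     name="transition")
        attention = SequenceContentAttention(
            state_names=transition.apply.states,
            attended_dim=200, match_dim=100, name="attention")
        att_transition = AttentionRecurrent(transition, attention,
                                            name="att_trans")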

    """
    def __init__(self, transition, attention, distribute=None,
                 add_contexts=True,
                 attended_name=None, attended_mask_name=None,
                 **kwargs):
        self._sequence_names = list(transition.apply.sequences)
        self._state_names = list(transition.apply.states)
        self._context_names = list(transition.apply.contexts)
        if add_contexts:
            if not attended_name:
                attended_name = 'attended'
            if not attended_mask_name:
                attended_mask_name = 'attended_mask'
            self._context_names += [attended_name, attended_mask_name]
        else:
            attended_name = self._context_names[0]
            attended_mask_name = self._context_names[1]
        if not distribute:
            normal_inputs = [name for name in self._sequence_names
                             if 'mask' not in name]
            distribute = Distribute(normal_inputs,
                                    attention.take_glimpses.outputs[0])

        self.transition = transition
        self.attention = attention
        self.distribute = distribute
        self.add_contexts = add_contexts
        self.attended_name = attended_name
        self.attended_mask_name = attended_mask_name

        self.preprocessed_attended_name = "preprocessed_" + self.attended_name

        self._glimpse_names = self.attention.take_glimpses.outputs
        # We need to determine which glimpses are fed back.
        # Currently we extract it from `take_glimpses` signature.
        self.previous_glimpses_needed = [
            name for name in self._glimpse_names
            if name in self.attention.take_glimpses.inputs]

        children = [self.transition, self.attention, self.distribute]
        kwargs.setdefault('children', []).extend(children)
        super(AttentionRecurrent, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.attention.state_dims = self.transition.get_dims(
            self.attention.state_names)
        self.attention.attended_dim = self.get_dim(self.attended_name)
        self.distribute.source_dim = self.attention.get_dim(
            self.distribute.source_name)
        self.distribute.target_dims = self.transition.get_dims(
            self.distribute.target_names)

    @application
    def take_glimpses(self, **kwargs):
        r"""Compute glimpses with the attention mechanism.

        A thin wrapper over `self.attention.take_glimpses`: takes care
        of choosing and renaming the necessary arguments.

        Parameters
        ----------
        \*\*kwargs
            Must contain the attended, previous step states and glimpses.
            Can optionally contain the attended mask and the preprocessed
            attended.

        Returns
        -------
        glimpses : list of :class:`~tensor.TensorVariable`
            Current step glimpses.

        """
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        glimpses_needed = dict_subset(glimpses, self.previous_glimpses_needed)
        result = self.attention.take_glimpses(
            kwargs.pop(self.attended_name),
            kwargs.pop(self.preprocessed_attended_name, None),
            kwargs.pop(self.attended_mask_name, None),
            **dict_union(states, glimpses_needed))
        # At this point kwargs may contain additional items.
        # e.g. AttentionRecurrent.transition.apply.contexts
        return result

    @take_glimpses.property('outputs')
    def take_glimpses_outputs(self):
        return self._glimpse_names

    @application
    def compute_states(self, **kwargs):
        r"""Compute current states when glimpses have already been computed.

        Combines an application of the `distribute` brick that alters the
        sequential inputs of the wrapped transition and an application of
        the wrapped transition. All unknown keyword arguments go to
        the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain everything that `self.transition` needs
            and in addition the current glimpses.

        Returns
        -------
        current_states : list of :class:`~tensor.TensorVariable`
            Current states computed by `self.transition`.

        """
        # make sure we are not popping the mask
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        sequences = dict_subset(kwargs, normal_inputs, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        if self.add_contexts:
            kwargs.pop(self.attended_name)
            # attended_mask_name can be optional
            kwargs.pop(self.attended_mask_name, None)

        sequences.update(self.distribute.apply(
            as_dict=True, **dict_subset(dict_union(sequences, glimpses),
                                        self.distribute.apply.inputs)))
        current_states = self.transition.apply(
            iterate=False, as_list=True,
            **dict_union(sequences, kwargs))
        return current_states

    @compute_states.property('outputs')
    def compute_states_outputs(self):
        return self._state_names

    @recurrent
    def do_apply(self, **kwargs):
        r"""Process a sequence attending the attended context every step.

        In addition to the original sequence this method also requires
        its preprocessed version, the one computed by the `preprocess`
        method of the attention mechanism. Unknown keyword arguments
        are passed to the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain current inputs, previous step states, contexts,
            the preprocessed attended context, previous step glimpses.

        Returns
        -------
        outputs : list of :class:`~tensor.TensorVariable`
            The current step states and glimpses.

        """
        attended = kwargs[self.attended_name]
        preprocessed_attended = kwargs.pop(self.preprocessed_attended_name)
        attended_mask = kwargs.get(self.attended_mask_name)
        sequences = dict_subset(kwargs, self._sequence_names, pop=True,
                                must_have=False)
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
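        # `glimpses` are the glimpses from the previous step: they are
        # passed to `take_glimpses` below, which forwards only the ones
        # the attention mechanism declared it needs.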

        current_glimpses = self.take_glimpses(
            as_dict=True,
            **dict_union(
                states, glimpses,
                {self.attended_name: attended,
                 self.attended_mask_name: attended_mask,
                 self.preprocessed_attended_name: preprocessed_attended}))
        current_states = self.compute_states(
            as_list=True,
            **dict_union(sequences, states, current_glimpses, kwargs))
        return current_states + list(current_glimpses.values())

    @do_apply.property('sequences')
    def do_apply_sequences(self):
        return self._sequence_names

    @do_apply.property('contexts')
    def do_apply_contexts(self):
        return self._context_names + [self.preprocessed_attended_name]

    @do_apply.property('states')
    def do_apply_states(self):
        return self._state_names + self._glimpse_names

    @do_apply.property('outputs')
    def do_apply_outputs(self):
        return self._state_names + self._glimpse_names

    @application
    def apply(self, **kwargs):
        """Preprocess a sequence attending the attended context at every step.

        Preprocesses the attended context and runs :meth:`do_apply`. See
        :meth:`do_apply` documentation for further information.

        """
        preprocessed_attended = self.attention.preprocess(
            kwargs[self.attended_name])
        return self.do_apply(
            **dict_union(kwargs,
                         {self.preprocessed_attended_name:
                          preprocessed_attended}))

    @apply.delegate
    def apply_delegate(self):
        # TODO: Nice interface for this trick?
        return self.do_apply.__get__(self, None)

    @apply.property('contexts')
    def apply_contexts(self):
        return self._context_names

    @application
    def initial_states(self, batch_size, **kwargs):
        return (pack(self.transition.initial_states(
                    batch_size, **kwargs)) +
                pack(self.attention.initial_glimpses(
                    batch_size, kwargs[self.attended_name])))

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.do_apply.states

    def get_dim(self, name):
        if name in self._glimpse_names:
            return self.attention.get_dim(name)
        if name == self.preprocessed_attended_name:
            (original_name,) = self.attention.preprocess.outputs
            return self.attention.get_dim(original_name)
        if self.add_contexts:
            if name == self.attended_name:
                return self.attention.get_dim(
                    self.attention.take_glimpses.inputs[0])
            if name == self.attended_mask_name:
                return 0
        return self.transition.get_dim(name)