r"""Attention mechanisms.

This module defines the interface of attention mechanisms and a few
concrete implementations. For a gentle introduction and usage examples see
the tutorial TODO.

An attention mechanism decides to what part of the input to pay attention.
It is typically used as a component of a recurrent network, though one can
imagine it used in other conditions as well. When the input is big and has
a certain structure, for instance when it is a sequence or an image, an
attention mechanism can be applied to extract only the information which
is relevant for the network in its current state.

For the purpose of documentation clarity, we fix the following terminology
in this file:

* *network* is the network, typically a recurrent one, which
  uses the attention mechanism.

* The network has *states*. Using this word in the plural might seem weird,
  but some recurrent networks like :class:`~blocks.bricks.recurrent.LSTM`
  do have several states.

* The big structured input, to which the attention mechanism is applied,
  is called the *attended*. When it has a variable structure, e.g. a
  sequence of variable length, there might be a *mask* associated with it.

* The information extracted by the attention from the attended is called
  a *glimpse*, or more specifically *glimpses*, because there might be a
  few pieces of this information.

Using this terminology, the attention mechanism computes glimpses
given the states of the network and the attended.

An example: in the machine translation network from [BCB]_ the attended is
a sequence of so-called annotations, that is states of a bidirectional
network that was driven by word embeddings of the source sentence. The
attention mechanism assigns weights to the annotations. The weighted sum of
the annotations is further used by the translation network to predict the
next word of the generated translation. The weights and the weighted sum
are the glimpses. A generalized attention mechanism for this paper is
represented here as :class:`SequenceContentAttention`.
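Schematically, with annotations :math:`h_1, \ldots, h_T` and energies
:math:`e_1, \ldots, e_T` computed from the annotations and the current
states, the glimpses of such a mechanism are

.. math:: w_i = \frac{\exp(e_i)}{\sum_j \exp(e_j)}, \qquad
          c = \sum_i w_i h_i,

where the notation is only illustrative and local to this docstring, not
taken from the paper.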

"""
from abc import ABCMeta, abstractmethod

from theano import tensor
from six import add_metaclass

from blocks.bricks import (Brick, Initializable, Sequence,
                           Feedforward, Linear, Tanh)
from blocks.bricks.base import lazy, application
from blocks.bricks.parallel import Parallel, Distribute
from blocks.bricks.recurrent import recurrent, BaseRecurrent
from blocks.utils import dict_union, dict_subset, pack


class AbstractAttention(Brick):
    """The common interface for attention bricks.

    First, see the module-level docstring for terminology.

    A generic attention mechanism functions as follows. Its inputs are the
    states of the network and the attended. Given these two it produces
    so-called *glimpses*, that is it extracts information from the attended
    which is necessary for the network in its current states.

    For computational reasons we separate the process described above into
    two stages:

    1. The preprocessing stage, :meth:`preprocess`, includes computations
       that do not involve the states. Those can often be performed in
       advance. The outcome of this stage is called *preprocessed_attended*.

    2. The main stage, :meth:`take_glimpses`, includes all the rest.

    When an attention mechanism is applied sequentially, some glimpses from
    the previous step might be necessary to compute the new ones. A
    typical example for that is when the focus position from the previous
    step is required. In such cases :meth:`take_glimpses` should specify
    such a need in its interface (its docstring explains how to do that).
    In addition :meth:`initial_glimpses` should specify some sensible
    initialization for the glimpses to be carried over.
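    Schematically, a sequential application of an attention brick looks
    roughly as follows (a sketch only; the exact arguments and glimpse
    names depend on the concrete brick)::

        preprocessed_attended = attention.preprocess(attended)
        for step in range(num_steps):
            glimpses = attention.take_glimpses(
                attended, preprocessed_attended, attended_mask,
                **states_and_previous_glimpses)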

    .. todo::

        Only single attended is currently allowed.

        :meth:`preprocess` and :meth:`initial_glimpses` might end up
        needing masks, which are currently not provided for them.

    Parameters
    ----------
    state_names : list
        The names of the network states.
    state_dims : list
        The state dimensions corresponding to `state_names`.
    attended_dim : int
        The dimension of the attended.

    Attributes
    ----------
    state_names : list
    state_dims : list
    attended_dim : int

    """
    @lazy(allocation=['state_names', 'state_dims', 'attended_dim'])
    def __init__(self, state_names, state_dims, attended_dim, **kwargs):
        self.state_names = state_names
        self.state_dims = state_dims
        self.attended_dim = attended_dim
        super(AbstractAttention, self).__init__(**kwargs)

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Perform the preprocessing of the attended.

        Stage 1 of the attention mechanism, see :class:`AbstractAttention`
        docstring for an explanation of stages. The default implementation
        simply returns the attended.

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        preprocessed_attended : :class:`~theano.Variable`
            The preprocessed attended.

        """
        return attended

    @abstractmethod
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **kwargs):
        r"""Extract glimpses from the attended given the current states.

        Stage 2 of the attention mechanism, see :class:`AbstractAttention`
        for an explanation of stages. If `preprocessed_attended` is not
        given, it should trigger stage 1.

        This application method *must* declare its inputs and outputs.
        The glimpses to be carried over are identified by their presence
        in both the inputs and the outputs lists. The attended *must* be
        the first input, the preprocessed attended *must* be the second
        one.

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.
        preprocessed_attended : :class:`~theano.Variable`, optional
            The preprocessed attended computed by :meth:`preprocess`. When
            not given, :meth:`preprocess` should be called.
        attended_mask : :class:`~theano.Variable`, optional
            The mask for the attended. This is required when the attended
            is a padded structure, e.g. when a number of sequences are
            forced to be the same length. The mask identifies the positions
            of the `attended` that actually contain information.
        \*\*kwargs : dict
            Includes the states and the glimpses to be carried over from
            the previous step in the case when the attention mechanism is
            applied sequentially.

        """
        pass

    @abstractmethod
    def initial_glimpses(self, batch_size, attended):
        """Return sensible initial values for carried over glimpses.

        Parameters
        ----------
        batch_size : int or :class:`~theano.Variable`
            The batch size.
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        initial_glimpses : list of :class:`~theano.Variable`
            The initial values for the requested glimpses. These might
            simply consist of zeros or be somehow extracted from
            the attended.

        """
        pass

    def get_dim(self, name):
        if name in ['attended', 'preprocessed_attended']:
            return self.attended_dim
        if name in ['attended_mask']:
            return 0
        return super(AbstractAttention, self).get_dim(name)


class GenericSequenceAttention(AbstractAttention):
    """Logic common for sequence attention mechanisms."""
    @application
    def compute_weights(self, energies, attended_mask):
        r"""Compute weights from energies in softmax-like fashion.

        .. todo::

            Use :class:`~blocks.bricks.Softmax`.

        Parameters
        ----------
        energies : :class:`~theano.Variable`
            The energies. Must be of the same shape as the mask.
        attended_mask : :class:`~theano.Variable`
            The mask for the attended. The index in the sequence must be
            the first dimension.

        Returns
        -------
        weights : :class:`~theano.Variable`
            Non-negative weights that sum to 1 and have the same shape
            as `energies`.
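            When a mask is given, they amount to a masked softmax of the
            energies along the sequence axis (the subtraction of the
            maximum energy in the code below only serves numerical
            stability):

            .. math:: w_i = \frac{m_i \exp(e_i)}
                                 {\sum_j m_j \exp(e_j) + [\forall j: m_j = 0]}

            with :math:`m` the mask; the Iverson bracket keeps the
            denominator non-zero when the mask consists of all zeros.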

        """
        # Stabilize energies first and then exponentiate
        energies = energies - energies.max(axis=0)
        unnormalized_weights = tensor.exp(energies)
        if attended_mask:
            unnormalized_weights *= attended_mask

        # If mask consists of all zeros use 1 as the normalization coefficient
        normalization = (unnormalized_weights.sum(axis=0) +
                         tensor.all(1 - attended_mask, axis=0))
        return unnormalized_weights / normalization

    @application
    def compute_weighted_averages(self, weights, attended):
        """Compute weighted averages of the attended sequence vectors.

        Parameters
        ----------
        weights : :class:`~theano.Variable`
            The weights. The shape must be equal to the attended shape
            without the last dimension.
        attended : :class:`~theano.Variable`
            The attended. The index in the sequence must be the first
            dimension.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            The weighted averages of the attended elements. The shape
            is equal to the attended shape with the first dimension
            dropped.

        """
        return (tensor.shape_padright(weights) * attended).sum(axis=0)


class SequenceContentAttention(GenericSequenceAttention, Initializable):
    r"""Attention mechanism that looks for relevant content in a sequence.

    This is the attention mechanism used in [BCB]_. The idea in a nutshell:

    1. The states and the sequence are transformed independently,

    2. The transformed states are summed with every transformed sequence
       element to obtain *match vectors*,

    3. A match vector is transformed into a single number interpreted as
       *energy*,

    4. Energies are normalized in softmax-like fashion. The resulting
       weights, which sum to one, are called *attention weights*,

    5. The weighted average of the sequence elements with the attention
       weights is computed.

    In terms of the :class:`AbstractAttention` documentation, the sequence
    is the attended. The weighted averages from 5 and the attention
    weights from 4 form the set of glimpses produced by this attention
    mechanism.
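    In symbols, for a single state :math:`s` and sequence elements
    :math:`a_1, \ldots, a_T` the computation is roughly (bias terms, the
    mask and the case of several states are omitted; the letters are local
    to this docstring)

    .. math:: e_j = v^\top \tanh(W s + U a_j), \qquad
              w_j = \frac{\exp(e_j)}{\sum_k \exp(e_k)}, \qquad
              g = \sum_j w_j a_j,

    where :math:`W`, :math:`U` and :math:`v` play the roles of the state
    transformer, the attended transformer and the energy computer
    respectively, and :math:`(g, w)` are the glimpses.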

    Parameters
    ----------
    state_names : list of str
        The names of the network states.
    attended_dim : int
        The dimension of the sequence elements.
    match_dim : int
        The dimension of the match vector.
    state_transformer : :class:`~.bricks.Brick`
        A prototype for state transformations. If ``None``,
        a linear transformation is used.
    attended_transformer : :class:`.Feedforward`
        The transformation to be applied to the sequence. If ``None``,
        an affine transformation is used.
    energy_computer : :class:`.Feedforward`
        Computes the energy from the match vector. If ``None``, an affine
        transformation preceded by :math:`\tanh` is used.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    .. [BCB] Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio. Neural
       Machine Translation by Jointly Learning to Align and Translate.

    """
    @lazy(allocation=['match_dim'])
    def __init__(self, match_dim, state_transformer=None,
                 attended_transformer=None, energy_computer=None, **kwargs):
        if not state_transformer:
            state_transformer = Linear(use_bias=False)
        self.match_dim = match_dim
        self.state_transformer = state_transformer

        self.state_transformers = Parallel(input_names=kwargs['state_names'],
                                           prototype=state_transformer,
                                           name="state_trans")
        if not attended_transformer:
            attended_transformer = Linear(name="preprocess")
        if not energy_computer:
            energy_computer = ShallowEnergyComputer(name="energy_comp")
        self.attended_transformer = attended_transformer
        self.energy_computer = energy_computer

        children = [self.state_transformers, attended_transformer,
                    energy_computer]
        kwargs.setdefault('children', []).extend(children)
        super(SequenceContentAttention, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.state_transformers.input_dims = self.state_dims
        self.state_transformers.output_dims = [self.match_dim
                                               for name in self.state_names]
        self.attended_transformer.input_dim = self.attended_dim
        self.attended_transformer.output_dim = self.match_dim
        self.energy_computer.input_dim = self.match_dim
        self.energy_computer.output_dim = 1

    @application
    def compute_energies(self, attended, preprocessed_attended, states):
        if not preprocessed_attended:
            preprocessed_attended = self.preprocess(attended)
        transformed_states = self.state_transformers.apply(as_dict=True,
                                                           **states)
        # Broadcasting of transformed states should be done automatically
        match_vectors = sum(transformed_states.values(),
                            preprocessed_attended)
        energies = self.energy_computer.apply(match_vectors).reshape(
            match_vectors.shape[:-1], ndim=match_vectors.ndim - 1)
        return energies

    @application(outputs=['weighted_averages', 'weights'])
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **states):
        r"""Compute attention weights and produce glimpses.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The sequence, time being the first dimension.
        preprocessed_attended : :class:`~tensor.TensorVariable`
            The preprocessed sequence. If ``None``, it is computed by
            calling :meth:`preprocess`.
        attended_mask : :class:`~tensor.TensorVariable`
            A 0/1 mask specifying available data. 0 means that the
            corresponding sequence element is fake.
        \*\*states
            The states of the network.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            Linear combinations of sequence elements with the attention
            weights.
        weights : :class:`~theano.Variable`
            The attention weights. The first dimension is batch, the
            second is time.

        """
        energies = self.compute_energies(attended, preprocessed_attended,
                                         states)
        weights = self.compute_weights(energies, attended_mask)
        weighted_averages = self.compute_weighted_averages(weights, attended)
        return weighted_averages, weights.T

    @take_glimpses.property('inputs')
    def take_glimpses_inputs(self):
        return (['attended', 'preprocessed_attended', 'attended_mask'] +
                self.state_names)

    @application(outputs=['weighted_averages', 'weights'])
    def initial_glimpses(self, batch_size, attended):
        return [tensor.zeros((batch_size, self.attended_dim)),
                tensor.zeros((batch_size, attended.shape[0]))]

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Preprocess the sequence for computing attention weights.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The attended sequence, time being the first dimension.

        """
        return self.attended_transformer.apply(attended)

    def get_dim(self, name):
        if name in ['weighted_averages']:
            return self.attended_dim
        if name in ['weights']:
            return 0
        return super(SequenceContentAttention, self).get_dim(name)


class ShallowEnergyComputer(Sequence, Initializable, Feedforward):
    r"""A simple energy computer: first tanh, then weighted sum.

    Parameters
    ----------
    use_bias : bool, optional
        Whether a bias should be added to the energies. Does not change
        anything if softmax normalization is used to produce the attention
        weights, but might be useful when e.g. spherical softmax is used.

    """
    @lazy()
    def __init__(self, use_bias=False, **kwargs):
        super(ShallowEnergyComputer, self).__init__(
            [Tanh().apply, Linear(use_bias=use_bias).apply], **kwargs)

    @property
    def input_dim(self):
        return self.children[1].input_dim

    @input_dim.setter
    def input_dim(self, value):
        self.children[1].input_dim = value

    @property
    def output_dim(self):
        return self.children[1].output_dim

    @output_dim.setter
    def output_dim(self, value):
        self.children[1].output_dim = value


@add_metaclass(ABCMeta)
class AbstractAttentionRecurrent(BaseRecurrent):
    """The interface for attention-equipped recurrent transitions.

    When a recurrent network is equipped with an attention mechanism its
    transition typically consists of two steps: (1) the glimpses are taken
    by the attention mechanism and (2) the next states are computed using
    the current states and the glimpses. It is required for certain use
    cases (such as a sequence generator) that, apart from a do-it-all
    recurrent application method, interfaces for the first and the second
    steps of the transition are provided.
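    Schematically, one recurrent step of such a transition reads (a sketch
    only; the actual argument and glimpse names depend on the wrapped
    bricks)::

        glimpses = transition.take_glimpses(attended=..., states=...,
                                            glimpses=...)
        states = transition.compute_states(inputs=..., states=...,
                                           glimpses=...)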

    """
    @abstractmethod
    def apply(self, **kwargs):
        """Compute next states taking glimpses on the way."""
        pass

    @abstractmethod
    def take_glimpses(self, **kwargs):
        """Compute glimpses given the current states."""
        pass

    @abstractmethod
    def compute_states(self, **kwargs):
        """Compute next states given current states and glimpses."""
        pass


class AttentionRecurrent(AbstractAttentionRecurrent, Initializable):
    """Combines an attention mechanism and a recurrent transition.

    This brick equips a recurrent transition with an attention mechanism.
    In order to do this two more contexts are added: one to be attended and
    a mask for it. It is also possible to use the contexts of the given
    recurrent transition for these purposes and not add any new ones,
    see the `add_contexts` parameter.

    At the beginning of each step the attention mechanism produces
    glimpses; these glimpses together with the current states are used to
    compute the next state and finish the transition. In some cases
    glimpses from the previous steps are also necessary for the attention
    mechanism, e.g. in order to focus on an area close to the one from the
    previous step. This is also supported: such glimpses become states of
    the new transition.

    To let the user control the way glimpses are used, this brick also
    takes a "distribute" brick as a parameter that distributes the
    information from glimpses across the sequential inputs of the wrapped
    recurrent transition.

    Parameters
    ----------
    transition : :class:`.BaseRecurrent`
        The recurrent transition.
    attention : :class:`~.bricks.Brick`
        The attention mechanism.
    distribute : :class:`~.bricks.Brick`, optional
        Distributes the information from glimpses across the input
        sequences of the transition. By default a :class:`.Distribute` is
        used, and those inputs containing the "mask" substring in their
        name are not affected.
    add_contexts : bool, optional
        If ``True``, new contexts for the attended and the attended mask
        are added to this transition, otherwise existing contexts of the
        wrapped transition are used. ``True`` by default.
    attended_name : str
        The name of the attended context. If ``None``, "attended"
        or the first context of the recurrent transition is used
        depending on the value of the `add_contexts` flag.
    attended_mask_name : str
        The name of the mask for the attended context. If ``None``,
        "attended_mask" or the second context of the recurrent transition
        is used depending on the value of the `add_contexts` flag.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    Wrapping your recurrent brick with this class makes all the
    states mandatory. If you feel this is a limitation for you, try
    to make it better! This restriction does not apply to sequences
    and contexts: those keep being as optional as they were for
    your brick.
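    A minimal construction sketch (the dimensions below are made up purely
    for illustration)::

        from blocks.bricks import Tanh
        from blocks.bricks.recurrent import SimpleRecurrent

        transition = SimpleRecurrent(dim=10, activation=Tanh(),
                                     name="transition")
        attention = SequenceContentAttention(
            state_names=transition.apply.states, attended_dim=8,
            match_dim=10, name="attention")
        wrapped = AttentionRecurrent(transition, attention)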

    Those coming to Blocks from Groundhog might recognize that this is
    a `RecurrentLayerWithSearch`, but on steroids :)

    """
    def __init__(self, transition, attention, distribute=None,
                 add_contexts=True,
                 attended_name=None, attended_mask_name=None,
                 **kwargs):
        self._sequence_names = list(transition.apply.sequences)
        self._state_names = list(transition.apply.states)
        self._context_names = list(transition.apply.contexts)
        if add_contexts:
            if not attended_name:
                attended_name = 'attended'
            if not attended_mask_name:
                attended_mask_name = 'attended_mask'
            self._context_names += [attended_name, attended_mask_name]
        else:
            attended_name = self._context_names[0]
            attended_mask_name = self._context_names[1]
        if not distribute:
            normal_inputs = [name for name in self._sequence_names
                             if 'mask' not in name]
            distribute = Distribute(normal_inputs,
                                    attention.take_glimpses.outputs[0])

        self.transition = transition
        self.attention = attention
        self.distribute = distribute
        self.add_contexts = add_contexts
        self.attended_name = attended_name
        self.attended_mask_name = attended_mask_name

        self.preprocessed_attended_name = "preprocessed_" + self.attended_name

        self._glimpse_names = self.attention.take_glimpses.outputs
        # We need to determine which glimpses are fed back.
        # Currently we extract it from `take_glimpses` signature.
        self.previous_glimpses_needed = [
            name for name in self._glimpse_names
            if name in self.attention.take_glimpses.inputs]

        children = [self.transition, self.attention, self.distribute]
        kwargs.setdefault('children', []).extend(children)
        super(AttentionRecurrent, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.attention.state_dims = self.transition.get_dims(
            self.attention.state_names)
        self.attention.attended_dim = self.get_dim(self.attended_name)
        self.distribute.source_dim = self.attention.get_dim(
            self.distribute.source_name)
        self.distribute.target_dims = self.transition.get_dims(
            self.distribute.target_names)

    @application
    def take_glimpses(self, **kwargs):
        r"""Compute glimpses with the attention mechanism.

        A thin wrapper over `self.attention.take_glimpses`: takes care
        of choosing and renaming the necessary arguments.

        Parameters
        ----------
        \*\*kwargs
            Must contain the attended, previous step states and glimpses.
            Can optionally contain the attended mask and the preprocessed
            attended.

        Returns
        -------
        glimpses : list of :class:`~tensor.TensorVariable`
            Current step glimpses.

        """
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        glimpses_needed = dict_subset(glimpses, self.previous_glimpses_needed)
        result = self.attention.take_glimpses(
            kwargs.pop(self.attended_name),
            kwargs.pop(self.preprocessed_attended_name, None),
            kwargs.pop(self.attended_mask_name, None),
            **dict_union(states, glimpses_needed))
        # At this point kwargs may contain additional items.
        # e.g. AttentionRecurrent.transition.apply.contexts
        return result

    @take_glimpses.property('outputs')
    def take_glimpses_outputs(self):
        return self._glimpse_names

    @application
    def compute_states(self, **kwargs):
        r"""Compute current states when glimpses have already been computed.

        Combines an application of the `distribute` brick that alters the
        sequential inputs of the wrapped transition and an application of
        the wrapped transition. All unknown keyword arguments go to
        the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain everything that `self.transition` needs
            and in addition the current glimpses.

        Returns
        -------
        current_states : list of :class:`~tensor.TensorVariable`
            Current states computed by `self.transition`.

        """
        # make sure we are not popping the mask
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        sequences = dict_subset(kwargs, normal_inputs, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        if self.add_contexts:
            kwargs.pop(self.attended_name)
            # attended_mask_name can be optional
            kwargs.pop(self.attended_mask_name, None)

        sequences.update(self.distribute.apply(
            as_dict=True, **dict_subset(dict_union(sequences, glimpses),
                                        self.distribute.apply.inputs)))
        current_states = self.transition.apply(
            iterate=False, as_list=True,
            **dict_union(sequences, kwargs))
        return current_states

    @compute_states.property('outputs')
    def compute_states_outputs(self):
        return self._state_names

    @recurrent
    def do_apply(self, **kwargs):
        r"""Process a sequence attending the attended context at every step.

        In addition to the original sequence this method also requires
        its preprocessed version, the one computed by the `preprocess`
        method of the attention mechanism. Unknown keyword arguments
        are passed to the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain current inputs, previous step states, contexts,
            the preprocessed attended context and previous step glimpses.

        Returns
        -------
        outputs : list of :class:`~tensor.TensorVariable`
            The current step states and glimpses.

        """
        attended = kwargs[self.attended_name]
        preprocessed_attended = kwargs.pop(self.preprocessed_attended_name)
        attended_mask = kwargs.get(self.attended_mask_name)
        sequences = dict_subset(kwargs, self._sequence_names, pop=True,
                                must_have=False)
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)

        current_glimpses = self.take_glimpses(
            as_dict=True,
            **dict_union(
                states, glimpses,
                {self.attended_name: attended,
                 self.attended_mask_name: attended_mask,
                 self.preprocessed_attended_name: preprocessed_attended}))
        current_states = self.compute_states(
            as_list=True,
            **dict_union(sequences, states, current_glimpses, kwargs))
        return current_states + list(current_glimpses.values())

    @do_apply.property('sequences')
    def do_apply_sequences(self):
        return self._sequence_names

    @do_apply.property('contexts')
    def do_apply_contexts(self):
        return self._context_names + [self.preprocessed_attended_name]

    @do_apply.property('states')
    def do_apply_states(self):
        return self._state_names + self._glimpse_names

    @do_apply.property('outputs')
    def do_apply_outputs(self):
        return self._state_names + self._glimpse_names

    @application
    def apply(self, **kwargs):
        """Preprocess a sequence attending the attended context at every step.

        Preprocesses the attended context and runs :meth:`do_apply`. See
        the :meth:`do_apply` documentation for further information.

        """
        preprocessed_attended = self.attention.preprocess(
            kwargs[self.attended_name])
        return self.do_apply(
            **dict_union(kwargs,
                         {self.preprocessed_attended_name:
                          preprocessed_attended}))

    @apply.delegate
    def apply_delegate(self):
        # TODO: Nice interface for this trick?
        return self.do_apply.__get__(self, None)

    @apply.property('contexts')
    def apply_contexts(self):
        return self._context_names

    @application
    def initial_states(self, batch_size, **kwargs):
        return (pack(self.transition.initial_states(
                    batch_size, **kwargs)) +
                pack(self.attention.initial_glimpses(
                    batch_size, kwargs[self.attended_name])))

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.do_apply.states

    def get_dim(self, name):
        if name in self._glimpse_names:
            return self.attention.get_dim(name)
        if name == self.preprocessed_attended_name:
            (original_name,) = self.attention.preprocess.outputs
            return self.attention.get_dim(original_name)
        if self.add_contexts:
            if name == self.attended_name:
                return self.attention.get_dim(
                    self.attention.take_glimpses.inputs[0])
            if name == self.attended_mask_name:
                return 0
        return self.transition.get_dim(name)