AttentionRecurrent   A

Complexity

Total Complexity 33

Size/Duplication

Total Lines 295
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
dl 0
loc 295
rs 9.3999
c 0
b 0
f 0
wmc 33

17 Methods

Rating   Name                          Duplication   Size   Complexity
A        initial_states()              0             6      1
B        compute_states()              0             38     4
B        do_apply()                    0             40     1
A        compute_states_outputs()      0             3      1
A        apply_delegate()              0             4      1
A        do_apply_sequences()          0             3      1
F        __init__()                    0             41     9
A        initial_states_outputs()      0             3      1
A        apply_contexts()              0             3      1
A        do_apply_outputs()            0             3      1
B        get_dim()                     0             13     6
A        apply()                       0             14     1
A        _push_allocation_config()     0             8      1
A        do_apply_contexts()           0             3      1
A        take_glimpses_outputs()       0             3      1
A        do_apply_states()             0             3      1
B        take_glimpses()               0             31     1
"""Attention mechanisms.

This module defines the interface of attention mechanisms and a few
concrete implementations. For a gentle introduction and usage examples see
the tutorial TODO.

An attention mechanism decides which part of the input to pay attention
to. It is typically used as a component of a recurrent network, though one
can imagine it being used in other settings as well. When the input is big
and has a certain structure, for instance when it is a sequence or an
image, an attention mechanism can be applied to extract only the
information that is relevant to the network in its current state.

For the purpose of documentation clarity, we fix the following terminology
in this file:

* The *network* is the network, typically a recurrent one, which
  uses the attention mechanism.

* The network has *states*. Using this word in the plural might seem odd,
  but some recurrent networks like :class:`~blocks.bricks.recurrent.LSTM`
  do have several states.

* The big structured input, to which the attention mechanism is applied,
  is called the *attended*. When it has variable structure, e.g. a sequence
  of variable length, there might be a *mask* associated with it.

* The information extracted by the attention from the attended is called
  a *glimpse*, or more specifically *glimpses*, because there might be a
  few pieces of this information.

Using this terminology, the attention mechanism computes glimpses
given the states of the network and the attended.

An example: in the machine translation network from [BCB]_ the attended is
a sequence of so-called annotations, that is, states of a bidirectional
network that was driven by word embeddings of the source sentence. The
attention mechanism assigns weights to the annotations. The weighted sum of
the annotations is further used by the translation network to predict the
next word of the generated translation. The weights and the weighted sum
are the glimpses. A generalization of the attention mechanism from this
paper is implemented here as :class:`SequenceContentAttention`.

"""
from abc import ABCMeta, abstractmethod

from theano import tensor
from six import add_metaclass

from blocks.bricks import (Brick, Initializable, Sequence,
                           Feedforward, Linear, Tanh)
from blocks.bricks.base import lazy, application
from blocks.bricks.parallel import Parallel, Distribute
from blocks.bricks.recurrent import recurrent, BaseRecurrent
from blocks.utils import dict_union, dict_subset, pack


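# The sketch below (an illustration, not part of the library) spells out the
# terminology above with plain Theano: given the `states` of the network and
# the `attended` sequence, an attention mechanism produces `glimpses`. Here
# the glimpse is a weighted average of the attended, with weights obtained
# from dot-product energies; it assumes the attended and the states share
# the same dimensionality.
def _attention_terminology_sketch():
    states = tensor.matrix('states')             # (batch, dim)
    attended = tensor.tensor3('attended')        # (time, batch, dim)
    # One energy per position of the attended, per batch element.
    energies = (attended * states).sum(axis=2)   # (time, batch)
    # Normalize the energies over time with a softmax.
    weights = tensor.nnet.softmax(energies.T).T  # (time, batch)
    # The glimpse: a weighted average of the attended positions.
    glimpses = (tensor.shape_padright(weights) * attended).sum(axis=0)
    return glimpses                              # (batch, dim)

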
class AbstractAttention(Brick):
    """The common interface for attention bricks.

    First, see the module-level docstring for terminology.

    A generic attention mechanism functions as follows. Its inputs are the
    states of the network and the attended. Given these two it produces
    so-called *glimpses*, that is, it extracts from the attended the
    information which is necessary for the network in its current states.

    For computational reasons we separate the process described above into
    two stages:

    1. The preprocessing stage, :meth:`preprocess`, includes computations
    that do not involve the states. These can often be performed in
    advance. The outcome of this stage is called *preprocessed_attended*.

    2. The main stage, :meth:`take_glimpses`, includes all the rest.

    When an attention mechanism is applied sequentially, some glimpses from
    the previous step might be necessary to compute the new ones. A
    typical example is when the focus position from the previous step is
    required. In such cases :meth:`take_glimpses` should specify this need
    in its interface (its docstring explains how to do that). In
    addition, :meth:`initial_glimpses` should specify a sensible
    initialization for the glimpses to be carried over.

    .. todo::

        Only a single attended is currently allowed.

        :meth:`preprocess` and :meth:`initial_glimpses` might end up
        needing masks, which are currently not provided for them.

    Parameters
    ----------
    state_names : list
        The names of the network states.
    state_dims : list
        The state dimensions corresponding to `state_names`.
    attended_dim : int
        The dimension of the attended.

    Attributes
    ----------
    state_names : list
    state_dims : list
    attended_dim : int

    """
    @lazy(allocation=['state_names', 'state_dims', 'attended_dim'])
    def __init__(self, state_names, state_dims, attended_dim, **kwargs):
        self.state_names = state_names
        self.state_dims = state_dims
        self.attended_dim = attended_dim
        super(AbstractAttention, self).__init__(**kwargs)

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Perform the preprocessing of the attended.

        Stage 1 of the attention mechanism, see :class:`AbstractAttention`
        docstring for an explanation of stages. The default implementation
        simply returns the attended.

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        preprocessed_attended : :class:`~theano.Variable`
            The preprocessed attended.

        """
        return attended

    @abstractmethod
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **kwargs):
        r"""Extract glimpses from the attended given the current states.

        Stage 2 of the attention mechanism, see :class:`AbstractAttention`
        for an explanation of stages. If `preprocessed_attended` is not
        given, this method should trigger stage 1.

        This application method *must* declare its inputs and outputs.
        The glimpses to be carried over are identified by their presence
        in both the inputs and the outputs list. The attended *must* be
        the first input, the preprocessed attended *must* be the second
        one.

        Parameters
        ----------
        attended : :class:`~theano.Variable`
            The attended.
        preprocessed_attended : :class:`~theano.Variable`, optional
            The preprocessed attended computed by :meth:`preprocess`. When
            not given, :meth:`preprocess` should be called.
        attended_mask : :class:`~theano.Variable`, optional
            The mask for the attended. This is required when the attended
            is padded, e.g. when a number of sequences are forced to be
            the same length. The mask identifies the positions of the
            `attended` that actually contain information.
        \*\*kwargs : dict
            Includes the states and the glimpses to be carried over from
            the previous step in the case when the attention mechanism is
            applied sequentially.

        """
        pass

    @abstractmethod
    def initial_glimpses(self, batch_size, attended):
        """Return sensible initial values for the carried-over glimpses.

        Parameters
        ----------
        batch_size : int or :class:`~theano.Variable`
            The batch size.
        attended : :class:`~theano.Variable`
            The attended.

        Returns
        -------
        initial_glimpses : list of :class:`~theano.Variable`
            The initial values for the requested glimpses. These might
            simply consist of zeros or be somehow extracted from
            the attended.

        """
        pass

    def get_dim(self, name):
        if name in ['attended', 'preprocessed_attended']:
            return self.attended_dim
        if name in ['attended_mask']:
            return 0
        return super(AbstractAttention, self).get_dim(name)


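# A minimal sketch (an assumption for illustration, not part of the library)
# of a concrete AbstractAttention subclass. The glimpse here is simply the
# mean of the attended over time, ignoring the states and the mask; a real
# mechanism would use them. Note how `take_glimpses` declares its inputs and
# outputs and how `get_dim` reports the glimpse dimension.
class _MeanAttention(AbstractAttention):
    @application(outputs=['mean_glimpse'])
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **states):
        # Average over the time dimension of the attended.
        return attended.mean(axis=0)

    @take_glimpses.property('inputs')
    def take_glimpses_inputs(self):
        return (['attended', 'preprocessed_attended', 'attended_mask'] +
                self.state_names)

    @application(outputs=['mean_glimpse'])
    def initial_glimpses(self, batch_size, attended):
        return tensor.zeros((batch_size, self.attended_dim))

    def get_dim(self, name):
        if name == 'mean_glimpse':
            return self.attended_dim
        return super(_MeanAttention, self).get_dim(name)

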
class GenericSequenceAttention(AbstractAttention):
    """Logic common to sequence attention mechanisms."""
    @application
    def compute_weights(self, energies, attended_mask):
        """Compute weights from energies in a softmax-like fashion.

        .. todo::

            Use :class:`~blocks.bricks.Softmax`.

        Parameters
        ----------
        energies : :class:`~theano.Variable`
            The energies. Must be of the same shape as the mask.
        attended_mask : :class:`~theano.Variable`
            The mask for the attended. The index in the sequence must be
            the first dimension.

        Returns
        -------
        weights : :class:`~theano.Variable`
            Non-negative weights of the same shape as `energies` that
            sum to 1 over the first dimension.

        """
        # Stabilize energies first and then exponentiate
        energies = energies - energies.max(axis=0)
        unnormalized_weights = tensor.exp(energies)
        if attended_mask:
            unnormalized_weights *= attended_mask

        # If the mask consists of all zeros, use 1 as the normalization
        # coefficient
        normalization = (unnormalized_weights.sum(axis=0) +
                         tensor.all(1 - attended_mask, axis=0))
        return unnormalized_weights / normalization

    @application
    def compute_weighted_averages(self, weights, attended):
        """Compute weighted averages of the attended sequence vectors.

        Parameters
        ----------
        weights : :class:`~theano.Variable`
            The weights. The shape must be equal to the attended shape
            without the last dimension.
        attended : :class:`~theano.Variable`
            The attended. The index in the sequence must be the first
            dimension.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            The weighted averages of the attended elements. The shape
            is equal to the attended shape with the first dimension
            dropped.

        """
        return (tensor.shape_padright(weights) * attended).sum(axis=0)


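# A small numeric check (not part of the library) of the masked,
# numerically-stabilized normalization performed by `compute_weights` above:
# energies are shifted by their per-sequence maximum, exponentiated, masked,
# and divided by their sum, so the weights of every sequence are
# non-negative and sum to one. The all-zero-mask correction term is omitted
# here because the example masks are not all zero.
def _masked_softmax_check():
    import numpy
    energies = numpy.array([[1.0, 2.0],
                            [3.0, 0.5],
                            [0.0, 0.0]])   # (time, batch)
    mask = numpy.array([[1., 1.],
                        [1., 1.],
                        [0., 1.]])         # last position of sequence 0 is padding
    shifted = energies - energies.max(axis=0)
    unnormalized = numpy.exp(shifted) * mask
    weights = unnormalized / unnormalized.sum(axis=0)
    assert numpy.allclose(weights.sum(axis=0), 1.0)
    return weights

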
class SequenceContentAttention(GenericSequenceAttention, Initializable):
    """Attention mechanism that looks for relevant content in a sequence.

    This is the attention mechanism used in [BCB]_. The idea in a nutshell:

    1. The states and the sequence are transformed independently,

    2. The transformed states are summed with every transformed sequence
       element to obtain *match vectors*,

    3. A match vector is transformed into a single number interpreted as
       *energy*,

    4. Energies are normalized in a softmax-like fashion. The resulting
       weights, which sum to one, are called *attention weights*,

    5. A weighted average of the sequence elements is computed using the
       attention weights.

    In terms of the :class:`AbstractAttention` documentation, the sequence
    is the attended. The weighted averages from step 5 and the attention
    weights from step 4 form the set of glimpses produced by this
    attention mechanism.

    Parameters
    ----------
    state_names : list of str
        The names of the network states.
    attended_dim : int
        The dimension of the sequence elements.
    match_dim : int
        The dimension of the match vector.
    state_transformer : :class:`~.bricks.Brick`
        A prototype for the state transformations. If ``None``,
        a linear transformation is used.
    attended_transformer : :class:`.Feedforward`
        The transformation to be applied to the sequence. If ``None``, an
        affine transformation is used.
    energy_computer : :class:`.Feedforward`
        Computes the energy from the match vector. If ``None``, an affine
        transformation preceded by :math:`tanh` is used.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    .. [BCB] Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio. Neural
       Machine Translation by Jointly Learning to Align and Translate.

    """
    @lazy(allocation=['match_dim'])
    def __init__(self, match_dim, state_transformer=None,
                 attended_transformer=None, energy_computer=None, **kwargs):
        if not state_transformer:
            state_transformer = Linear(use_bias=False)
        self.match_dim = match_dim
        self.state_transformer = state_transformer

        self.state_transformers = Parallel(input_names=kwargs['state_names'],
                                           prototype=state_transformer,
                                           name="state_trans")
        if not attended_transformer:
            attended_transformer = Linear(name="preprocess")
        if not energy_computer:
            energy_computer = ShallowEnergyComputer(name="energy_comp")
        self.attended_transformer = attended_transformer
        self.energy_computer = energy_computer

        children = [self.state_transformers, attended_transformer,
                    energy_computer]
        kwargs.setdefault('children', []).extend(children)
        super(SequenceContentAttention, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.state_transformers.input_dims = self.state_dims
        self.state_transformers.output_dims = [self.match_dim
                                               for name in self.state_names]
        self.attended_transformer.input_dim = self.attended_dim
        self.attended_transformer.output_dim = self.match_dim
        self.energy_computer.input_dim = self.match_dim
        self.energy_computer.output_dim = 1

    @application
    def compute_energies(self, attended, preprocessed_attended, states):
        if not preprocessed_attended:
            preprocessed_attended = self.preprocess(attended)
        transformed_states = self.state_transformers.apply(as_dict=True,
                                                           **states)
        # Broadcasting of transformed states should be done automatically
        match_vectors = sum(transformed_states.values(),
                            preprocessed_attended)
        energies = self.energy_computer.apply(match_vectors).reshape(
            match_vectors.shape[:-1], ndim=match_vectors.ndim - 1)
        return energies

    @application(outputs=['weighted_averages', 'weights'])
    def take_glimpses(self, attended, preprocessed_attended=None,
                      attended_mask=None, **states):
        r"""Compute attention weights and produce glimpses.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The sequence, time is the first dimension.
        preprocessed_attended : :class:`~tensor.TensorVariable`
            The preprocessed sequence. If ``None``, it is computed by
            calling :meth:`preprocess`.
        attended_mask : :class:`~tensor.TensorVariable`
            A 0/1 mask specifying available data. 0 means that the
            corresponding sequence element is fake.
        \*\*states
            The states of the network.

        Returns
        -------
        weighted_averages : :class:`~theano.Variable`
            Linear combinations of the sequence elements with the
            attention weights.
        weights : :class:`~theano.Variable`
            The attention weights. The first dimension is batch, the second
            is time.

        """
        energies = self.compute_energies(attended, preprocessed_attended,
                                         states)
        weights = self.compute_weights(energies, attended_mask)
        weighted_averages = self.compute_weighted_averages(weights, attended)
        return weighted_averages, weights.T

    @take_glimpses.property('inputs')
    def take_glimpses_inputs(self):
        return (['attended', 'preprocessed_attended', 'attended_mask'] +
                self.state_names)

    @application(outputs=['weighted_averages', 'weights'])
    def initial_glimpses(self, batch_size, attended):
        return [tensor.zeros((batch_size, self.attended_dim)),
                tensor.zeros((batch_size, attended.shape[0]))]

    @application(inputs=['attended'], outputs=['preprocessed_attended'])
    def preprocess(self, attended):
        """Preprocess the sequence for computing attention weights.

        Parameters
        ----------
        attended : :class:`~tensor.TensorVariable`
            The attended sequence, time is the first dimension.

        """
        return self.attended_transformer.apply(attended)

    def get_dim(self, name):
        if name in ['weighted_averages']:
            return self.attended_dim
        if name in ['weights']:
            return 0
        return super(SequenceContentAttention, self).get_dim(name)


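# A usage sketch with assumed dimensions (the state name 'states', the
# dimensions and the initialization schemes are illustrative choices, not
# prescribed by the library): build a SequenceContentAttention for a network
# with a single state, initialize it, and take glimpses for a batch of
# masked sequences.
def _sequence_content_attention_sketch(state_dim=10, attended_dim=12,
                                       match_dim=30):
    from blocks.initialization import IsotropicGaussian, Constant

    attention = SequenceContentAttention(
        state_names=['states'], state_dims=[state_dim],
        attended_dim=attended_dim, match_dim=match_dim,
        weights_init=IsotropicGaussian(0.01), biases_init=Constant(0),
        name="attention")
    attention.initialize()

    attended = tensor.tensor3('attended')           # (time, batch, attended_dim)
    attended_mask = tensor.matrix('attended_mask')  # (time, batch), 0/1
    states = tensor.matrix('states')                # (batch, state_dim)
    # The glimpses: weighted averages of the attended and the weights used.
    weighted_averages, weights = attention.take_glimpses(
        attended, attended_mask=attended_mask, states=states)
    return weighted_averages, weights

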
class ShallowEnergyComputer(Sequence, Initializable, Feedforward):
    """A simple energy computer: first a tanh, then a weighted sum.

    Parameters
    ----------
    use_bias : bool, optional
        Whether a bias should be added to the energies. Does not change
        anything if softmax normalization is used to produce the attention
        weights, but might be useful when e.g. a spherical softmax is used.

    """
    @lazy()
    def __init__(self, use_bias=False, **kwargs):
        super(ShallowEnergyComputer, self).__init__(
            [Tanh().apply, Linear(use_bias=use_bias).apply], **kwargs)

    @property
    def input_dim(self):
        return self.children[1].input_dim

    @input_dim.setter
    def input_dim(self, value):
        self.children[1].input_dim = value

    @property
    def output_dim(self):
        return self.children[1].output_dim

    @output_dim.setter
    def output_dim(self, value):
        self.children[1].output_dim = value


@add_metaclass(ABCMeta)
class AbstractAttentionRecurrent(BaseRecurrent):
    """The interface for attention-equipped recurrent transitions.

    When a recurrent network is equipped with an attention mechanism, its
    transition typically consists of two steps: (1) the glimpses are taken
    by the attention mechanism and (2) the next states are computed using
    the current states and the glimpses. Certain use cases (such as a
    sequence generator) require that, in addition to a do-it-all recurrent
    application method, interfaces for the first and the second step of
    the transition are provided.

    """
    @abstractmethod
    def apply(self, **kwargs):
        """Compute next states taking glimpses on the way."""
        pass

    @abstractmethod
    def take_glimpses(self, **kwargs):
        """Compute glimpses given the current states."""
        pass

    @abstractmethod
    def compute_states(self, **kwargs):
        """Compute next states given current states and glimpses."""
        pass


class AttentionRecurrent(AbstractAttentionRecurrent, Initializable):
    """Combines an attention mechanism and a recurrent transition.

    This brick equips a recurrent transition with an attention mechanism.
    In order to do this two more contexts are added: one to be attended and
    a mask for it. It is also possible to use the contexts of the given
    recurrent transition for these purposes and not add any new ones,
    see the `add_contexts` parameter.

    At the beginning of each step the attention mechanism produces
    glimpses; these glimpses together with the current states are used to
    compute the next state and finish the transition. In some cases
    glimpses from the previous steps are also necessary for the attention
    mechanism, e.g. in order to focus on an area close to the one from the
    previous step. This is also supported: such glimpses become states of
    the new transition.

    To let the user control the way glimpses are used, this brick also
    takes a "distribute" brick as a parameter that distributes the
    information from the glimpses across the sequential inputs of the
    wrapped recurrent transition.

    Parameters
    ----------
    transition : :class:`.BaseRecurrent`
        The recurrent transition.
    attention : :class:`~.bricks.Brick`
        The attention mechanism.
    distribute : :class:`~.bricks.Brick`, optional
        Distributes the information from the glimpses across the input
        sequences of the transition. By default a :class:`.Distribute` is
        used, and those inputs containing the "mask" substring in their
        name are not affected.
    add_contexts : bool, optional
        If ``True``, new contexts for the attended and the attended mask
        are added to this transition, otherwise existing contexts of the
        wrapped transition are used. ``True`` by default.
    attended_name : str
        The name of the attended context. If ``None``, "attended"
        or the first context of the recurrent transition is used
        depending on the value of the `add_contexts` flag.
    attended_mask_name : str
        The name of the mask for the attended context. If ``None``,
        "attended_mask" or the second context of the recurrent transition
        is used depending on the value of the `add_contexts` flag.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    Wrapping your recurrent brick with this class makes all the
    states mandatory. If you feel this is a limitation for you, try
    to make it better! This restriction does not apply to sequences
    and contexts: those remain as optional as they were for
    your brick.

    Those coming to Blocks from Groundhog might recognize that this is
    a `RecurrentLayerWithSearch`, but on steroids :)

    """
    def __init__(self, transition, attention, distribute=None,
                 add_contexts=True,
                 attended_name=None, attended_mask_name=None,
                 **kwargs):
        self._sequence_names = list(transition.apply.sequences)
        self._state_names = list(transition.apply.states)
        self._context_names = list(transition.apply.contexts)
        if add_contexts:
            if not attended_name:
                attended_name = 'attended'
            if not attended_mask_name:
                attended_mask_name = 'attended_mask'
            self._context_names += [attended_name, attended_mask_name]
        else:
            attended_name = self._context_names[0]
            attended_mask_name = self._context_names[1]
        if not distribute:
            normal_inputs = [name for name in self._sequence_names
                             if 'mask' not in name]
            distribute = Distribute(normal_inputs,
                                    attention.take_glimpses.outputs[0])

        self.transition = transition
        self.attention = attention
        self.distribute = distribute
        self.add_contexts = add_contexts
        self.attended_name = attended_name
        self.attended_mask_name = attended_mask_name

        self.preprocessed_attended_name = "preprocessed_" + self.attended_name

        self._glimpse_names = self.attention.take_glimpses.outputs
        # We need to determine which glimpses are fed back.
        # Currently we extract it from the `take_glimpses` signature.
        self.previous_glimpses_needed = [
            name for name in self._glimpse_names
            if name in self.attention.take_glimpses.inputs]

        children = [self.transition, self.attention, self.distribute]
        kwargs.setdefault('children', []).extend(children)
        super(AttentionRecurrent, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.attention.state_dims = self.transition.get_dims(
            self.attention.state_names)
        self.attention.attended_dim = self.get_dim(self.attended_name)
        self.distribute.source_dim = self.attention.get_dim(
            self.distribute.source_name)
        self.distribute.target_dims = self.transition.get_dims(
            self.distribute.target_names)

    @application
    def take_glimpses(self, **kwargs):
        r"""Compute glimpses with the attention mechanism.

        A thin wrapper over `self.attention.take_glimpses`: takes care
        of choosing and renaming the necessary arguments.

        Parameters
        ----------
        \*\*kwargs
            Must contain the attended, previous step states and glimpses.
            Can optionally contain the attended mask and the preprocessed
            attended.

        Returns
        -------
        glimpses : list of :class:`~tensor.TensorVariable`
            Current step glimpses.

        """
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        glimpses_needed = dict_subset(glimpses, self.previous_glimpses_needed)
        result = self.attention.take_glimpses(
            kwargs.pop(self.attended_name),
            kwargs.pop(self.preprocessed_attended_name, None),
            kwargs.pop(self.attended_mask_name, None),
            **dict_union(states, glimpses_needed))
        # At this point kwargs may contain additional items,
        # e.g. AttentionRecurrent.transition.apply.contexts
        return result

    @take_glimpses.property('outputs')
    def take_glimpses_outputs(self):
        return self._glimpse_names

    @application
    def compute_states(self, **kwargs):
        r"""Compute current states when glimpses have already been computed.

        Combines an application of the `distribute` brick that alters the
        sequential inputs of the wrapped transition and an application of
        the wrapped transition. All unknown keyword arguments go to
        the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain everything that `self.transition` needs
            and in addition the current glimpses.

        Returns
        -------
        current_states : list of :class:`~tensor.TensorVariable`
            Current states computed by `self.transition`.

        """
        # make sure we are not popping the mask
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        sequences = dict_subset(kwargs, normal_inputs, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        if self.add_contexts:
            kwargs.pop(self.attended_name)
            # attended_mask_name can be optional
            kwargs.pop(self.attended_mask_name, None)

        sequences.update(self.distribute.apply(
            as_dict=True, **dict_subset(dict_union(sequences, glimpses),
                                        self.distribute.apply.inputs)))
        current_states = self.transition.apply(
            iterate=False, as_list=True,
            **dict_union(sequences, kwargs))
        return current_states

    @compute_states.property('outputs')
    def compute_states_outputs(self):
        return self._state_names

    @recurrent
    def do_apply(self, **kwargs):
        r"""Process a sequence, attending to the attended context at every step.

        In addition to the original sequence this method also requires
        its preprocessed version, the one computed by the `preprocess`
        method of the attention mechanism. Unknown keyword arguments
        are passed to the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain the current inputs, previous step states,
            contexts, the preprocessed attended context and the previous
            step glimpses.

        Returns
        -------
        outputs : list of :class:`~tensor.TensorVariable`
            The current step states and glimpses.

        """
        attended = kwargs[self.attended_name]
        preprocessed_attended = kwargs.pop(self.preprocessed_attended_name)
        attended_mask = kwargs.get(self.attended_mask_name)
        sequences = dict_subset(kwargs, self._sequence_names, pop=True,
                                must_have=False)
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)

        current_glimpses = self.take_glimpses(
            as_dict=True,
            **dict_union(
                states, glimpses,
                {self.attended_name: attended,
                 self.attended_mask_name: attended_mask,
                 self.preprocessed_attended_name: preprocessed_attended}))
        current_states = self.compute_states(
            as_list=True,
            **dict_union(sequences, states, current_glimpses, kwargs))
        return current_states + list(current_glimpses.values())

    @do_apply.property('sequences')
    def do_apply_sequences(self):
        return self._sequence_names

    @do_apply.property('contexts')
    def do_apply_contexts(self):
        return self._context_names + [self.preprocessed_attended_name]

    @do_apply.property('states')
    def do_apply_states(self):
        return self._state_names + self._glimpse_names

    @do_apply.property('outputs')
    def do_apply_outputs(self):
        return self._state_names + self._glimpse_names

    @application
    def apply(self, **kwargs):
        """Preprocess the attended, then process the sequence.

        Preprocesses the attended context and runs :meth:`do_apply`. See
        the :meth:`do_apply` documentation for further information.

        """
        preprocessed_attended = self.attention.preprocess(
            kwargs[self.attended_name])
        return self.do_apply(
            **dict_union(kwargs,
                         {self.preprocessed_attended_name:
                          preprocessed_attended}))

    @apply.delegate
    def apply_delegate(self):
        # TODO: Nice interface for this trick?
        return self.do_apply.__get__(self, None)

    @apply.property('contexts')
    def apply_contexts(self):
        return self._context_names

    @application
    def initial_states(self, batch_size, **kwargs):
        return (pack(self.transition.initial_states(
                     batch_size, **kwargs)) +
                pack(self.attention.initial_glimpses(
                     batch_size, kwargs[self.attended_name])))

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.do_apply.states

    def get_dim(self, name):
        if name in self._glimpse_names:
            return self.attention.get_dim(name)
        if name == self.preprocessed_attended_name:
            (original_name,) = self.attention.preprocess.outputs
            return self.attention.get_dim(original_name)
        if self.add_contexts:
            if name == self.attended_name:
                return self.attention.get_dim(
                    self.attention.take_glimpses.inputs[0])
            if name == self.attended_mask_name:
                return 0
        return self.transition.get_dim(name)
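

# A sketch (not from the library's documentation; the transition, the
# dimensions and the initialization schemes below are illustrative
# assumptions) of wrapping a plain recurrent transition with a
# SequenceContentAttention via AttentionRecurrent and applying it to a
# batch of masked sequences.
def _attention_recurrent_sketch(dim=10, attended_dim=12, match_dim=30):
    from blocks.bricks.recurrent import SimpleRecurrent
    from blocks.initialization import IsotropicGaussian, Constant

    transition = SimpleRecurrent(dim=dim, activation=Tanh(),
                                 name="transition")
    attention = SequenceContentAttention(
        state_names=transition.apply.states,
        attended_dim=attended_dim, match_dim=match_dim, name="attention")
    attention_recurrent = AttentionRecurrent(transition, attention)
    attention_recurrent.weights_init = IsotropicGaussian(0.01)
    attention_recurrent.biases_init = Constant(0)
    attention_recurrent.initialize()

    inputs = tensor.tensor3('inputs')               # (time, batch, dim)
    attended = tensor.tensor3('attended')           # (time, batch, attended_dim)
    attended_mask = tensor.matrix('attended_mask')  # (time, batch), 0/1
    # Returns the state sequence followed by the glimpse sequences
    # (weighted averages and attention weights).
    outputs = attention_recurrent.apply(
        inputs=inputs, attended=attended, attended_mask=attended_mask)
    return outputs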