Issues (119)

blocks/bricks/sequence_generators.py (6 issues)

"""Sequence generation framework.

Recurrent networks are often used to generate and model sequences.
Examples include language modelling, machine translation, handwriting
synthesis, etc. A typical pattern in this context is that sequence
elements are generated one after another, and every generated element
is fed back into the recurrent network state. Sometimes an attention
mechanism is also used to condition sequence generation on some
structured input, such as another sequence or an image.

This module provides :class:`SequenceGenerator`, which builds a
sequence generating network from three main components:

* a core recurrent transition, e.g. :class:`~blocks.bricks.recurrent.LSTM`
  or :class:`~blocks.bricks.recurrent.GatedRecurrent`

* a readout component that can produce sequence elements using
  the network state and the information from the attention mechanism

* an attention mechanism (see :mod:`~blocks.bricks.attention` for
  more information)

A usage sketch can be found at the end of this module.

Implementation-wise, :class:`SequenceGenerator` fully relies on
:class:`BaseSequenceGenerator`. At the level of the latter, attention
is mandatory; moreover, it must be part of the recurrent transition
(see :class:`~blocks.bricks.attention.AttentionRecurrent`). To simulate
optional attention, :class:`SequenceGenerator` wraps the pure recurrent
network in :class:`FakeAttentionRecurrent`.

"""
from abc import ABCMeta, abstractmethod

from six import add_metaclass
from theano import tensor

from blocks.bricks import Initializable, Random, Bias, NDimensionalSoftmax
from blocks.bricks.base import application, Brick, lazy
from blocks.bricks.parallel import Fork, Merge
from blocks.bricks.lookup import LookupTable
from blocks.bricks.recurrent import recurrent
from blocks.bricks.attention import (
    AbstractAttentionRecurrent, AttentionRecurrent)
from blocks.roles import add_role, COST
from blocks.utils import dict_union, dict_subset


class BaseSequenceGenerator(Initializable):
    r"""A generic sequence generator.

    This class combines two components, a readout network and an
    attention-equipped recurrent transition, into a context-dependent
    sequence generator. A third component must also be given, which
    forks feedback from the readout network to obtain inputs for the
    transition.

    The class provides two methods: :meth:`generate` and :meth:`cost`. The
    former is to actually generate sequences and the latter is to compute
    the cost of generating given sequences.

    A description of the generation algorithm follows.

    **Definitions and notation:**

    * States :math:`s_i` of the generator are the states of the transition
      as specified in `transition.state_names`.

    * Contexts of the generator are the contexts of the
      transition as specified in `transition.context_names`.

    * Glimpses :math:`g_i` are intermediate entities computed at every
      generation step from states, contexts and the previous step glimpses.
      They are computed in the transition's `apply` method when not given
      or by explicitly calling the transition's `take_glimpses` method. The
      set of glimpses considered is specified in
      `transition.glimpse_names`.

    * Outputs :math:`y_i` are produced at every step and form the output
      sequence. A generation cost :math:`c_i` is assigned to each output.

    **Algorithm:**

    1. Initialization.

       .. math::

           y_0 = readout.initial\_outputs(contexts)\\
           s_0, g_0 = transition.initial\_states(contexts)\\
           i = 1\\

       By default all recurrent bricks from :mod:`~blocks.bricks.recurrent`
       have trainable initial states initialized with zeros. Subclass them
       or :class:`~blocks.bricks.recurrent.BaseRecurrent` directly to get
       custom initial states.

    2. New glimpses are computed:

       .. math:: g_i = transition.take\_glimpses(
           s_{i-1}, g_{i-1}, contexts)

    3. A new output is generated by the readout and its cost is
       computed:

       .. math::

            f_{i-1} = readout.feedback(y_{i-1}) \\
            r_i = readout.readout(f_{i-1}, s_{i-1}, g_i, contexts) \\
            y_i = readout.emit(r_i) \\
            c_i = readout.cost(r_i, y_i)

       Note that the *new* glimpses and the *old* states are used at this
       step. The reason for not merging all readout methods into one is
       to make an efficient implementation of :meth:`cost` possible.

    4. New states are computed and the step counter is incremented:

       .. math::

           f_i = readout.feedback(y_i) \\
           s_i = transition.compute\_states(s_{i-1}, g_i,
                fork.apply(f_i), contexts) \\
           i = i + 1

    5. Back to step 2 if the desired sequence
       length has not yet been reached.

    A plain-Python paraphrase of this loop is sketched after the class
    definition.

    | A scheme of the algorithm described above follows.

    .. image:: /_static/sequence_generator_scheme.png
            :height: 500px
            :width: 500px

    ..

    Parameters
    ----------
    readout : instance of :class:`AbstractReadout`
        The readout component of the sequence generator.
    transition : instance of :class:`AbstractAttentionRecurrent`
        The transition component of the sequence generator.
    fork : :class:`~.bricks.Brick`
        The brick to compute the transition's inputs from the feedback.

    See Also
    --------
    :class:`.Initializable` : for initialization parameters

    :class:`SequenceGenerator` : a more user-friendly interface to this\
        brick

    """
    @lazy()
    def __init__(self, readout, transition, fork, **kwargs):
        self.readout = readout
        self.transition = transition
        self.fork = fork

        children = [self.readout, self.fork, self.transition]
        kwargs.setdefault('children', []).extend(children)
        super(BaseSequenceGenerator, self).__init__(**kwargs)

    @property
    def _state_names(self):
        return self.transition.compute_states.outputs

    @property
    def _context_names(self):
        return self.transition.apply.contexts

    @property
    def _glimpse_names(self):
        return self.transition.take_glimpses.outputs

    def _push_allocation_config(self):
        # Configure readout. That involves `get_dim` requests
        # to the transition. To make sure that it answers
        # correctly we should finish its configuration first.
        self.transition.push_allocation_config()
        transition_sources = (self._state_names + self._context_names +
                              self._glimpse_names)
        self.readout.source_dims = [self.transition.get_dim(name)
                                    if name in transition_sources
                                    else self.readout.get_dim(name)
                                    for name in self.readout.source_names]

        # Configure fork. For reasons similar to those outlined above,
        # we first push the `readout` configuration.
        self.readout.push_allocation_config()
        feedback_name, = self.readout.feedback.outputs
        self.fork.input_dim = self.readout.get_dim(feedback_name)
        self.fork.output_dims = self.transition.get_dims(
            self.fork.apply.outputs)

    @application
    def cost(self, application_call, outputs, mask=None, **kwargs):
        """Returns the average cost over the minibatch.

        The cost is computed by averaging the sum of per-token costs for
        each sequence over the minibatch.

        .. warning::
            Note that the computed cost can be problematic when batches
            consist of vastly different sequence lengths.

        Parameters
        ----------
        outputs : :class:`~tensor.TensorVariable`
            A 3-dimensional tensor (2-dimensional for integer outputs)
            containing the output sequences. Axis 0 must stand for time,
            axis 1 for the position in the batch.
        mask : :class:`~tensor.TensorVariable`
            The binary matrix identifying fake outputs.

        Returns
        -------
        cost : :class:`~tensor.Variable`
            Theano variable for the cost, computed by summing over
            timesteps and then averaging over the minibatch.

        Notes
        -----
        The contexts are expected as keyword arguments.

        Adds the average cost per sequence element as an `AUXILIARY`
        variable to the computational graph, with the name
        ``per_sequence_element``.

        """
        # Compute the sum of costs
        costs = self.cost_matrix(outputs, mask=mask, **kwargs)
        cost = tensor.mean(costs.sum(axis=0))
        add_role(cost, COST)

        # Add an auxiliary variable for the per sequence element cost
        application_call.add_auxiliary_variable(
            (costs.sum() / mask.sum()) if mask is not None else costs.mean(),
            name='per_sequence_element')
        return cost
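
    # Illustration of the two cost statistics above (hypothetical numbers):
    # for a minibatch of two sequences with per-token costs [1, 1, 1] and
    # [2, 2] (the shorter one padded and masked out), `cost` returns
    # mean(1 + 1 + 1, 2 + 2) = 3.5, while the auxiliary
    # `per_sequence_element` variable equals (3 + 4) / 5 = 1.4.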

    @application
    def cost_matrix(self, application_call, outputs, mask=None, **kwargs):
        """Returns generation costs for output sequences.

        See Also
        --------
        :meth:`cost` : Scalar cost.

        """
        # We assume the data has axes (time, batch, features, ...)
        batch_size = outputs.shape[1]

        # Prepare input for the iterative part
        states = dict_subset(kwargs, self._state_names, must_have=False)
        # masks in context are optional (e.g. `attended_mask`)
        contexts = dict_subset(kwargs, self._context_names, must_have=False)
        feedback = self.readout.feedback(outputs)
        inputs = self.fork.apply(feedback, as_dict=True)

        # Run the recurrent network
        results = self.transition.apply(
            mask=mask, return_initial_states=True, as_dict=True,
            **dict_union(inputs, states, contexts))

        # Separate the deliverables. The last states are discarded: they
        # are not used to predict any output symbol. The initial glimpses
        # are discarded because they are not used for prediction.
        # Remember, glimpses are computed _before_ the output stage, and
        # states are computed after.
        states = {name: results[name][:-1] for name in self._state_names}
        glimpses = {name: results[name][1:] for name in self._glimpse_names}

        # Compute the cost
        feedback = tensor.roll(feedback, 1, 0)
        feedback = tensor.set_subtensor(
            feedback[0],
            self.readout.feedback(self.readout.initial_outputs(batch_size)))
        readouts = self.readout.readout(
            feedback=feedback, **dict_union(states, glimpses, contexts))
        costs = self.readout.cost(readouts, outputs)
        if mask is not None:
            costs *= mask

        for name, variable in list(glimpses.items()) + list(states.items()):
            application_call.add_auxiliary_variable(
                variable.copy(), name=name)

        # These variables can be used to initialize the initial states of
        # the next batch using the last states of the current batch.
        for name in self._state_names + self._glimpse_names:
            application_call.add_auxiliary_variable(
                results[name][-1].copy(), name=name+"_final_value")

        return costs
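
    # Illustration of the shift above (hypothetical values): if the output
    # sequence along time is [y_1, y_2, y_3], then after `roll` and
    # `set_subtensor` the feedback rows are
    # [feedback(y_0), feedback(y_1), feedback(y_2)], i.e. at step i the
    # readout sees the feedback of the *previous* output, with y_0 coming
    # from `initial_outputs`.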

    @recurrent
    def generate(self, outputs, **kwargs):
        """A sequence generation step.

        Parameters
        ----------
        outputs : :class:`~tensor.TensorVariable`
            The outputs from the previous step.

        Notes
        -----
        The contexts, previous states and glimpses are expected as keyword
        arguments.

        """
        states = dict_subset(kwargs, self._state_names)
        # masks in context are optional (e.g. `attended_mask`)
        contexts = dict_subset(kwargs, self._context_names, must_have=False)
        glimpses = dict_subset(kwargs, self._glimpse_names)

        next_glimpses = self.transition.take_glimpses(
            as_dict=True, **dict_union(states, glimpses, contexts))
        next_readouts = self.readout.readout(
            feedback=self.readout.feedback(outputs),
            **dict_union(states, next_glimpses, contexts))
        next_outputs = self.readout.emit(next_readouts)
        next_costs = self.readout.cost(next_readouts, next_outputs)
        next_feedback = self.readout.feedback(next_outputs)
        next_inputs = (self.fork.apply(next_feedback, as_dict=True)
                       if self.fork else {'feedback': next_feedback})
        next_states = self.transition.compute_states(
            as_list=True,
            **dict_union(next_inputs, states, next_glimpses, contexts))
        return (next_states + [next_outputs] +
                list(next_glimpses.values()) + [next_costs])

    @generate.delegate
    def generate_delegate(self):
        return self.transition.apply

    @generate.property('states')
    def generate_states(self):
        return self._state_names + ['outputs'] + self._glimpse_names

    @generate.property('outputs')
    def generate_outputs(self):
        return (self._state_names + ['outputs'] +
                self._glimpse_names + ['costs'])

    def get_dim(self, name):
        if name in (self._state_names + self._context_names +
                    self._glimpse_names):
            return self.transition.get_dim(name)
        elif name == 'outputs':
            return self.readout.get_dim(name)
        return super(BaseSequenceGenerator, self).get_dim(name)

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        # TODO: support dict of outputs for application methods
        # to simplify this code.
        state_dict = dict(
            self.transition.initial_states(
                batch_size, as_dict=True, *args, **kwargs),
            outputs=self.readout.initial_outputs(batch_size))
        return [state_dict[state_name]
                for state_name in self.generate.states]

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.generate.states
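

# Illustrative sketch (not part of the original module): the generation
# algorithm from the `BaseSequenceGenerator` docstring written as plain
# Python. `readout`, `transition` and `fork` are hypothetical objects
# implementing the interfaces described above, with glimpses and contexts
# simplified to single values.
def _generation_loop_sketch(readout, transition, fork, contexts,
                            batch_size, n_steps):
    outputs = []
    y = readout.initial_outputs(batch_size)           # step 1: y_0
    s, g = transition.initial_states(batch_size)      # step 1: s_0, g_0
    for _ in range(n_steps):                          # step 5: loop
        g = transition.take_glimpses(s, g, contexts)  # step 2: new glimpses
        r = readout.readout(readout.feedback(y), s, g, contexts)  # step 3
        y = readout.emit(r)                           # step 3: new output
        c = readout.cost(r, y)                        # step 3: its cost
        s = transition.compute_states(                # step 4: new states
            s, g, fork.apply(readout.feedback(y)), contexts)
        outputs.append((y, c))
    return outputs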


@add_metaclass(ABCMeta)
Issue: This abstract class seems to be used only once. Abstract classes
which are used only once can usually be inlined into the class which
already uses this abstract class.
class AbstractReadout(Initializable):
    """The interface for the readout component of a sequence generator.

    The readout component of a sequence generator is a bridge between
    the core recurrent network and the output sequence.

    Parameters
    ----------
    source_names : list
        A list of the source names (outputs) that are needed for the
        readout part, e.g. ``['states']`` or
        ``['states', 'weighted_averages']`` or ``['states', 'feedback']``.
    readout_dim : int
        The dimension of the readout.

    Attributes
    ----------
    source_names : list
    readout_dim : int

    See Also
    --------
    :class:`BaseSequenceGenerator` : see how exactly a readout is used

    :class:`Readout` : the typically used readout brick

    """
    @lazy(allocation=['source_names', 'readout_dim'])
    def __init__(self, source_names, readout_dim, **kwargs):
        self.source_names = source_names
        self.readout_dim = readout_dim
        super(AbstractReadout, self).__init__(**kwargs)

    @abstractmethod
    def emit(self, readouts):
        """Produce outputs from readouts.

        Parameters
        ----------
        readouts : :class:`~theano.Variable`
            Readouts produced by the :meth:`readout` method,
            of shape `(batch_size, readout_dim)`.

        """
        pass

    @abstractmethod
    def cost(self, readouts, outputs):
        """Compute the generation cost of outputs given readouts.

        Parameters
        ----------
        readouts : :class:`~theano.Variable`
            Readouts produced by the :meth:`readout` method,
            of shape `(..., readout_dim)`.
        outputs : :class:`~theano.Variable`
            Outputs whose cost should be computed. Should have as many
            dimensions as, or one dimension fewer than, `readouts`. If
            `readouts` has `n` dimensions, the first `n - 1` dimensions
            of `outputs` should match those of `readouts`.

        """
        pass

    @abstractmethod
    def initial_outputs(self, batch_size):
        """Compute initial outputs for the generator's first step.

        In the notation from the :class:`BaseSequenceGenerator`
        documentation this method should compute :math:`y_0`.

        """
        pass

    @abstractmethod
    def readout(self, **kwargs):
        r"""Compute the readout vector from states, glimpses, etc.

        Parameters
        ----------
        \*\*kwargs: dict
            Contains sequence generator states, glimpses,
            contexts and feedback from the previous outputs.

        """
        pass

    @abstractmethod
    def feedback(self, outputs):
        """Feeds outputs back to be used as inputs of the transition."""
        pass


class Readout(AbstractReadout):
    r"""Readout brick with separated emitter and feedback parts.

    :class:`Readout` combines a few bits and pieces into an object
    that can be used as the readout component in
    :class:`BaseSequenceGenerator`. This includes an emitter brick,
    to which :meth:`emit`, :meth:`cost` and :meth:`initial_outputs`
    calls are delegated, a feedback brick to which :meth:`feedback`
    functionality is delegated, and a pipeline to actually compute
    readouts from all the sources (see the `source_names` attribute
    of :class:`AbstractReadout`).

    The readout computation pipeline is constructed from the `merge` and
    `post_merge` bricks, whose responsibilities are described in the
    respective docstrings.

    Parameters
    ----------
    emitter : an instance of :class:`AbstractEmitter`
        The emitter component.
    feedback_brick : an instance of :class:`AbstractFeedback`
        The feedback component.
    merge : :class:`~.bricks.Brick`, optional
        A brick that takes the sources given in `source_names` as an input
        and combines them into a single output. If given, `merge_prototype`
        cannot be given.
    merge_prototype : :class:`.Feedforward`, optional
        If `merge` isn't given, the transformation given by
        `merge_prototype` is applied to each input before being summed. By
        default a :class:`.Linear` transformation without biases is used.
        If given, `merge` cannot be given.
    post_merge : :class:`.Feedforward`, optional
        This transformation is applied to the merged inputs. By default
        :class:`.Bias` is used.
    merged_dim : int, optional
        The input dimension of `post_merge`, i.e. the output dimension of
        `merge` (or `merge_prototype`). If not given, it is assumed to be
        the same as `readout_dim` (i.e. `post_merge` is assumed not to
        change dimensions).
    \*\*kwargs : dict
        Passed to the parent's constructor.

    See Also
    --------
    :class:`BaseSequenceGenerator` : see how exactly a readout is used

    :class:`AbstractEmitter`, :class:`AbstractFeedback`

    """
    def __init__(self, emitter=None, feedback_brick=None,
                 merge=None, merge_prototype=None,
                 post_merge=None, merged_dim=None, **kwargs):

        if not emitter:
            emitter = TrivialEmitter(kwargs['readout_dim'])
        if not feedback_brick:
            feedback_brick = TrivialFeedback(kwargs['readout_dim'])
        if not merge:
            merge = Merge(input_names=kwargs['source_names'],
                          prototype=merge_prototype)
        if not post_merge:
            post_merge = Bias(dim=kwargs['readout_dim'])
        if not merged_dim:
            merged_dim = kwargs['readout_dim']
        self.emitter = emitter
        self.feedback_brick = feedback_brick
        self.merge = merge
        self.post_merge = post_merge
        self.merged_dim = merged_dim

        children = [self.emitter, self.feedback_brick, self.merge,
                    self.post_merge]
        kwargs.setdefault('children', []).extend(children)
        super(Readout, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.emitter.readout_dim = self.get_dim('readouts')
        self.feedback_brick.output_dim = self.get_dim('outputs')
        self.merge.input_names = self.source_names
        self.merge.input_dims = self.source_dims
Issue: The instance of Readout does not seem to have a member named
source_dims. This check looks for calls to members that are
non-existent; these calls will fail. The member could have been renamed
or removed.
        self.merge.output_dim = self.merged_dim
        self.post_merge.input_dim = self.merged_dim
        self.post_merge.output_dim = self.readout_dim

    @application
    def readout(self, **kwargs):
        merged = self.merge.apply(**{name: kwargs[name]
                                     for name in self.merge.input_names})
        merged = self.post_merge.apply(merged)
        return merged

    @application
    def emit(self, readouts):
        return self.emitter.emit(readouts)

    @application
    def cost(self, readouts, outputs):
        return self.emitter.cost(readouts, outputs)

    @application
    def initial_outputs(self, batch_size):
        return self.emitter.initial_outputs(batch_size)

    @application(outputs=['feedback'])
    def feedback(self, outputs):
        return self.feedback_brick.feedback(outputs)

    def get_dim(self, name):
        if name == 'outputs':
            return self.emitter.get_dim(name)
        elif name == 'feedback':
            return self.feedback_brick.get_dim(name)
        elif name == 'readouts':
            return self.readout_dim
        return super(Readout, self).get_dim(name)
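

# Sketch (hypothetical, not part of the original module): with the default
# `merge` (a per-source Linear transformation without biases) and the
# default `post_merge` (a Bias), `Readout.readout` effectively computes a
# sum of linearly transformed sources plus a bias, as in:
def _default_readout_pipeline_sketch(sources, weights, bias):
    """sources: dict of name -> (batch, source_dim) arrays; weights:
    dict of name -> (source_dim, readout_dim) matrices; bias: a
    (readout_dim,) vector."""
    merged = sum(sources[name].dot(weights[name]) for name in sources)
    return merged + bias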


@add_metaclass(ABCMeta)
class AbstractEmitter(Brick):
    """The interface for the emitter component of a readout.

    Attributes
    ----------
    readout_dim : int
        The dimension of the readout. It is set by the
        :class:`Readout` brick when the allocation configuration
        is pushed.

    See Also
    --------
    :class:`Readout`

    :class:`SoftmaxEmitter` : for integer outputs

    Notes
    -----
    An important detail about the emitter cost is that it will be
    evaluated with inputs of different dimensions, so it has to be
    flexible enough to handle this. The two ways in which it can be
    applied are:

        1. In :meth:`BaseSequenceGenerator.cost_matrix`, where it will
        be applied to the whole sequence at once.

        2. In :meth:`BaseSequenceGenerator.generate`, where it will be
        applied to only one step of the sequence.

    """
    @abstractmethod
    def emit(self, readouts):
        """Implements the respective method of :class:`Readout`."""
        pass

    @abstractmethod
    def cost(self, readouts, outputs):
        """Implements the respective method of :class:`Readout`."""
        pass

    @abstractmethod
    def initial_outputs(self, batch_size):
        """Implements the respective method of :class:`Readout`."""
        pass
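
# Note on the two application modes above (illustrative, for integer
# outputs as in SoftmaxEmitter): in `cost_matrix` the emitter cost sees
# `readouts` of shape (time, batch, readout_dim) and `outputs` of shape
# (time, batch); in `generate` it sees (batch, readout_dim) and (batch,).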


@add_metaclass(ABCMeta)
class AbstractFeedback(Brick):
    """The interface for the feedback component of a readout.

    See Also
    --------
    :class:`Readout`

    :class:`LookupFeedback` : for integer outputs

    """
    @abstractmethod
    def feedback(self, outputs):
        """Implements the respective method of :class:`Readout`."""
        pass


class TrivialEmitter(AbstractEmitter):
    """An emitter for the trivial case when readouts are outputs.

    Parameters
    ----------
    readout_dim : int
        The dimension of the readout.

    Notes
    -----
    By default :meth:`cost` always returns a zero tensor.

    """
    @lazy(allocation=['readout_dim'])
    def __init__(self, readout_dim, **kwargs):
        super(TrivialEmitter, self).__init__(**kwargs)
        self.readout_dim = readout_dim

    @application
    def emit(self, readouts):
        return readouts

    @application
    def cost(self, readouts, outputs):
        return tensor.zeros_like(outputs)

    @application
    def initial_outputs(self, batch_size):
        return tensor.zeros((batch_size, self.readout_dim))

    def get_dim(self, name):
        if name == 'outputs':
            return self.readout_dim
        return super(TrivialEmitter, self).get_dim(name)


class SoftmaxEmitter(AbstractEmitter, Initializable, Random):
    """A softmax emitter for the case of integer outputs.

    Interprets readout elements as energies corresponding to their indices.

    Parameters
    ----------
    initial_output : int or a scalar :class:`~theano.Variable`
        The initial output.

    """
    def __init__(self, initial_output=0, **kwargs):
        self.initial_output = initial_output
        self.softmax = NDimensionalSoftmax()
        children = [self.softmax]
        kwargs.setdefault('children', []).extend(children)
        super(SoftmaxEmitter, self).__init__(**kwargs)

    @application
    def probs(self, readouts):
        return self.softmax.apply(readouts, extra_ndim=readouts.ndim - 2)
Issue: The keyword extra_ndim does not seem to exist for the method call.

    @application
    def emit(self, readouts):
        probs = self.probs(readouts)
        batch_size = probs.shape[0]
        pvals_flat = probs.reshape((batch_size, -1))
        generated = self.theano_rng.multinomial(pvals=pvals_flat)
        return generated.reshape(probs.shape).argmax(axis=-1)

    @application
    def cost(self, readouts, outputs):
        # WARNING: unfortunately this application method works
        # just fine when `readouts` and `outputs` have
        # different dimensions. Be careful!
        return self.softmax.categorical_cross_entropy(
            outputs, readouts, extra_ndim=readouts.ndim - 2)
Issue: The keyword extra_ndim does not seem to exist for the method call.

    @application
    def initial_outputs(self, batch_size):
        return self.initial_output * tensor.ones((batch_size,), dtype='int64')

    def get_dim(self, name):
        if name == 'outputs':
            return 0
        return super(SoftmaxEmitter, self).get_dim(name)
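

# Sketch (not part of the original module): the sampling trick used in
# `SoftmaxEmitter.emit`, in plain numpy. A multinomial draw with a single
# trial yields a one-hot vector per batch row; `argmax` converts it back
# to an integer index distributed according to the softmax probabilities.
def _sample_indices_sketch(probs, rng):
    """probs: a numpy array of shape (batch_size, num_classes); rng: a
    numpy.random.RandomState."""
    import numpy
    one_hot = numpy.array([rng.multinomial(1, row) for row in probs])
    return one_hot.argmax(axis=-1)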


class TrivialFeedback(AbstractFeedback):
    """A feedback brick for the case when readouts are outputs."""
    @lazy(allocation=['output_dim'])
    def __init__(self, output_dim, **kwargs):
        super(TrivialFeedback, self).__init__(**kwargs)
        self.output_dim = output_dim

    @application(outputs=['feedback'])
    def feedback(self, outputs):
        return outputs

    def get_dim(self, name):
        if name == 'feedback':
            return self.output_dim
        return super(TrivialFeedback, self).get_dim(name)


class LookupFeedback(AbstractFeedback, Initializable):
    """A feedback brick for the case when readouts are integers.

    Stores and retrieves distributed representations of integers.

    """
    def __init__(self, num_outputs=None, feedback_dim=None, **kwargs):
        self.num_outputs = num_outputs
        self.feedback_dim = feedback_dim

        self.lookup = LookupTable(num_outputs, feedback_dim)
        children = [self.lookup]
        kwargs.setdefault('children', []).extend(children)
        super(LookupFeedback, self).__init__(**kwargs)

    def _push_allocation_config(self):
        self.lookup.length = self.num_outputs
        self.lookup.dim = self.feedback_dim

    @application
    def feedback(self, outputs):
        assert self.output_dim == 0
Issue: The instance of LookupFeedback does not seem to have a member
named output_dim. This check looks for calls to members that are
non-existent; these calls will fail. The member could have been renamed
or removed.
        return self.lookup.apply(outputs)

    def get_dim(self, name):
        if name == 'feedback':
            return self.feedback_dim
        return super(LookupFeedback, self).get_dim(name)


class FakeAttentionRecurrent(AbstractAttentionRecurrent, Initializable):
    """Adds a fake attention interface to a transition.

    :class:`BaseSequenceGenerator` requires its transition brick to
    support the
    :class:`~blocks.bricks.attention.AbstractAttentionRecurrent`
    interface, that is, to have an embedded attention mechanism. For the
    cases when no attention is required (e.g. language modelling or
    encoder-decoder models), :class:`FakeAttentionRecurrent` is used to
    wrap a usual recurrent brick. The resulting brick has no glimpses
    and simply passes all states and contexts to the wrapped one.

    .. todo::

        Get rid of this brick and support attention-less transitions
        in :class:`BaseSequenceGenerator`.

    """
    def __init__(self, transition, **kwargs):
        self.transition = transition

        self.state_names = transition.apply.states
        self.context_names = transition.apply.contexts
        self.glimpse_names = []

        children = [self.transition]
        kwargs.setdefault('children', []).extend(children)
        super(FakeAttentionRecurrent, self).__init__(**kwargs)

    @application
    def apply(self, *args, **kwargs):
        return self.transition.apply(*args, **kwargs)

    @apply.delegate
    def apply_delegate(self):
        return self.transition.apply

    @application
    def compute_states(self, *args, **kwargs):
        return self.transition.apply(iterate=False, *args, **kwargs)

    @compute_states.delegate
    def compute_states_delegate(self):
        return self.transition.apply

    @application(outputs=[])
    def take_glimpses(self, *args, **kwargs):
Issue: The argument args seems to be unused.
        return None

    @application
    def initial_states(self, batch_size, *args, **kwargs):
        return self.transition.initial_states(batch_size,
                                              *args, **kwargs)

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.transition.apply.states

    def get_dim(self, name):
        return self.transition.get_dim(name)


class SequenceGenerator(BaseSequenceGenerator):
    r"""A more user-friendly interface for :class:`BaseSequenceGenerator`.

    Parameters
    ----------
    readout : instance of :class:`AbstractReadout`
        The readout component for the sequence generator.
    transition : instance of :class:`.BaseRecurrent`
        The recurrent transition to be used in the sequence generator.
        Will be combined with `attention`, if one is given.
    attention : object, optional
        The attention mechanism to be added to ``transition``,
        an instance of
        :class:`~blocks.bricks.attention.AbstractAttention`.
    add_contexts : bool
        If ``True``, the :class:`.AttentionRecurrent` wrapping the
        `transition` will add additional contexts for the attended and
        its mask.
    \*\*kwargs : dict
        All keyword arguments are passed to the base class. If the
        `fork` keyword argument is not provided, a :class:`.Fork` is
        created that forks all sequential inputs of the transition
        whose names do not contain a "mask" substring.

    """
    def __init__(self, readout, transition, attention=None,
                 add_contexts=True, **kwargs):
        normal_inputs = [name for name in transition.apply.sequences
                         if 'mask' not in name]
        kwargs.setdefault('fork', Fork(normal_inputs))
        if attention:
            transition = AttentionRecurrent(
                transition, attention,
                add_contexts=add_contexts, name="att_trans")
        else:
            transition = FakeAttentionRecurrent(transition,
                                                name="with_fake_attention")
        super(SequenceGenerator, self).__init__(
            readout, transition, **kwargs)
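

# Usage sketch (illustrative, not part of the original module): assembling
# a small language-model-style generator from the components above. The
# dimensions, the choice of SimpleRecurrent, and the initialization
# schemes are placeholder assumptions. Since no `attention` is passed,
# the transition is wrapped in `FakeAttentionRecurrent` internally.
def _build_generator_sketch(alphabet_size=10, feedback_dim=8, state_dim=16):
    from blocks.bricks import Tanh
    from blocks.bricks.recurrent import SimpleRecurrent
    from blocks.initialization import Constant, IsotropicGaussian

    generator = SequenceGenerator(
        Readout(readout_dim=alphabet_size, source_names=['states'],
                emitter=SoftmaxEmitter(name='emitter'),
                feedback_brick=LookupFeedback(alphabet_size, feedback_dim),
                name='readout'),
        SimpleRecurrent(dim=state_dim, activation=Tanh(), name='transition'),
        weights_init=IsotropicGaussian(0.01), biases_init=Constant(0),
        name='generator')
    generator.initialize()
    # `generator.cost(outputs, mask=mask)` would then give the training
    # cost for an int64 (time, batch) matrix of output sequences, and
    # `generator.generate(...)` would sample sequences.
    return generator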
873