1 | """Sequence generation framework. |
||
2 | |||
3 | Recurrent networks are often used to generate/model sequences. |
||
4 | Examples include language modelling, machine translation, handwriting |
||
5 | synthesis, etc.. A typical pattern in this context is that |
||
6 | sequence elements are generated one often another, and every generated |
||
7 | element is fed back into the recurrent network state. Sometimes |
||
8 | also an attention mechanism is used to condition sequence generation |
||
9 | on some structured input like another sequence or an image. |
||
10 | |||
11 | This module provides :class:`SequenceGenerator` that builds a sequence |
||
12 | generating network from three main components: |
||
13 | |||
14 | * a core recurrent transition, e.g. :class:`~blocks.bricks.recurrent.LSTM` |
||
15 | or :class:`~blocks.bricks.recurrent.GatedRecurrent` |
||
16 | |||
17 | * a readout component that can produce sequence elements using |
||
18 | the network state and the information from the attention mechanism |
||
19 | |||
20 | * an attention mechanism (see :mod:`~blocks.bricks.attention` for |
||
21 | more information) |
||
22 | |||
23 | Implementation-wise :class:`SequenceGenerator` fully relies on |
||
24 | :class:`BaseSequenceGenerator`. At the level of the latter an |
||
25 | attention is mandatory, moreover it must be a part of the recurrent |
||
26 | transition (see :class:`~blocks.bricks.attention.AttentionRecurrent`). |
||
27 | To simulate optional attention, :class:`SequenceGenerator` wraps the |
||
28 | pure recurrent network in :class:`FakeAttentionRecurrent`. |
||
29 | |||
30 | """ |
||
31 | from abc import ABCMeta, abstractmethod |
||
32 | |||
33 | from six import add_metaclass |
||
34 | from theano import tensor |
||
35 | |||
36 | from blocks.bricks import Initializable, Random, Bias, NDimensionalSoftmax |
||
37 | from blocks.bricks.base import application, Brick, lazy |
||
38 | from blocks.bricks.parallel import Fork, Merge |
||
39 | from blocks.bricks.lookup import LookupTable |
||
40 | from blocks.bricks.recurrent import recurrent |
||
41 | from blocks.bricks.attention import ( |
||
42 | AbstractAttentionRecurrent, AttentionRecurrent) |
||
43 | from blocks.roles import add_role, COST |
||
44 | from blocks.utils import dict_union, dict_subset |
||
45 | |||
46 | |||
47 | class BaseSequenceGenerator(Initializable): |
||
48 | r"""A generic sequence generator. |
||
49 | |||
50 | This class combines two components, a readout network and an |
||
51 | attention-equipped recurrent transition, into a context-dependent |
||
52 | sequence generator. Third component must be also given which |
||
53 | forks feedback from the readout network to obtain inputs for the |
||
54 | transition. |
||
55 | |||
56 | The class provides two methods: :meth:`generate` and :meth:`cost`. The |
||
57 | former is to actually generate sequences and the latter is to compute |
||
58 | the cost of generating given sequences. |
||
59 | |||
60 | The generation algorithm description follows. |
||
61 | |||
62 | **Definitions and notation:** |
||
63 | |||
64 | * States :math:`s_i` of the generator are the states of the transition |
||
65 | as specified in `transition.state_names`. |
||
66 | |||
67 | * Contexts of the generator are the contexts of the |
||
68 | transition as specified in `transition.context_names`. |
||
69 | |||
70 | * Glimpses :math:`g_i` are intermediate entities computed at every |
||
71 | generation step from states, contexts and the previous step glimpses. |
||
72 | They are computed in the transition's `apply` method when not given |
||
73 | or by explicitly calling the transition's `take_glimpses` method. The |
||
74 | set of glimpses considered is specified in |
||
75 | `transition.glimpse_names`. |
||
76 | |||
77 | * Outputs :math:`y_i` are produced at every step and form the output |
||
78 | sequence. A generation cost :math:`c_i` is assigned to each output. |
||
79 | |||
80 | **Algorithm:** |
||
81 | |||
82 | 1. Initialization. |
||
83 | |||
84 | .. math:: |
||
85 | |||
86 | y_0 = readout.initial\_outputs(contexts)\\ |
||
87 | s_0, g_0 = transition.initial\_states(contexts)\\ |
||
88 | i = 1\\ |
||
89 | |||
90 | By default all recurrent bricks from :mod:`~blocks.bricks.recurrent` |
||
91 | have trainable initial states initialized with zeros. Subclass them |
||
92 | or :class:`~blocks.bricks.recurrent.BaseRecurrent` directly to get |
||
93 | custom initial states. |
||
94 | |||
95 | 2. New glimpses are computed: |
||
96 | |||
97 | .. math:: g_i = transition.take\_glimpses( |
||
98 | s_{i-1}, g_{i-1}, contexts) |
||
99 | |||
100 | 3. A new output is generated by the readout and its cost is |
||
101 | computed: |
||
102 | |||
103 | .. math:: |
||
104 | |||
105 | f_{i-1} = readout.feedback(y_{i-1}) \\ |
||
106 | r_i = readout.readout(f_{i-1}, s_{i-1}, g_i, contexts) \\ |
||
107 | y_i = readout.emit(r_i) \\ |
||
108 | c_i = readout.cost(r_i, y_i) |
||
109 | |||
110 | Note that the *new* glimpses and the *old* states are used at this |
||
111 | step. The reason for not merging all readout methods into one is |
||
112 | to make an efficient implementation of :meth:`cost` possible. |
||
113 | |||
114 | 4. New states are computed and iteration is done: |
||
115 | |||
116 | .. math:: |
||
117 | |||
118 | f_i = readout.feedback(y_i) \\ |
||
119 | s_i = transition.compute\_states(s_{i-1}, g_i, |
||
120 | fork.apply(f_i), contexts) \\ |
||
121 | i = i + 1 |
||
122 | |||
123 | 5. Back to step 2 if the desired sequence |
||
124 | length has not been yet reached. |
||
125 | |||
126 | | A scheme of the algorithm described above follows. |
||
127 | |||
128 | .. image:: /_static/sequence_generator_scheme.png |
||
129 | :height: 500px |
||
130 | :width: 500px |
||
131 | |||
132 | .. |
||
133 | |||
134 | Parameters |
||
135 | ---------- |
||
136 | readout : instance of :class:`AbstractReadout` |
||
137 | The readout component of the sequence generator. |
||
138 | transition : instance of :class:`AbstractAttentionRecurrent` |
||
139 | The transition component of the sequence generator. |
||
140 | fork : :class:`~.bricks.Brick` |
||
141 | The brick to compute the transition's inputs from the feedback. |
||
142 | |||
143 | See Also |
||
144 | -------- |
||
145 | :class:`.Initializable` : for initialization parameters |
||
146 | |||
147 | :class:`SequenceGenerator` : more user friendly interface to this\ |
||
148 | brick |
||
149 | |||
150 | """ |
||
151 | @lazy() |
||
152 | def __init__(self, readout, transition, fork, **kwargs): |
||
153 | self.readout = readout |
||
154 | self.transition = transition |
||
155 | self.fork = fork |
||
156 | |||
157 | children = [self.readout, self.fork, self.transition] |
||
158 | kwargs.setdefault('children', []).extend(children) |
||
159 | super(BaseSequenceGenerator, self).__init__(**kwargs) |
||
160 | |||
161 | @property |
||
162 | def _state_names(self): |
||
163 | return self.transition.compute_states.outputs |
||
164 | |||
165 | @property |
||
166 | def _context_names(self): |
||
167 | return self.transition.apply.contexts |
||
168 | |||
169 | @property |
||
170 | def _glimpse_names(self): |
||
171 | return self.transition.take_glimpses.outputs |
||
172 | |||
173 | def _push_allocation_config(self): |
||
174 | # Configure readout. That involves `get_dim` requests |
||
175 | # to the transition. To make sure that it answers |
||
176 | # correctly we should finish its configuration first. |
||
177 | self.transition.push_allocation_config() |
||
178 | transition_sources = (self._state_names + self._context_names + |
||
179 | self._glimpse_names) |
||
180 | self.readout.source_dims = [self.transition.get_dim(name) |
||
181 | if name in transition_sources |
||
182 | else self.readout.get_dim(name) |
||
183 | for name in self.readout.source_names] |
||
184 | |||
185 | # Configure fork. For similar reasons as outlined above, |
||
186 | # first push `readout` configuration. |
||
187 | self.readout.push_allocation_config() |
||
188 | feedback_name, = self.readout.feedback.outputs |
||
189 | self.fork.input_dim = self.readout.get_dim(feedback_name) |
||
190 | self.fork.output_dims = self.transition.get_dims( |
||
191 | self.fork.apply.outputs) |
||
192 | |||
193 | @application |
||
194 | def cost(self, application_call, outputs, mask=None, **kwargs): |
||
195 | """Returns the average cost over the minibatch. |
||
196 | |||
197 | The cost is computed by averaging the sum of per token costs for |
||
198 | each sequence over the minibatch. |
||
199 | |||
200 | .. warning:: |
||
201 | Note that, the computed cost can be problematic when batches |
||
202 | consist of vastly different sequence lengths. |
||
203 | |||
204 | Parameters |
||
205 | ---------- |
||
206 | outputs : :class:`~tensor.TensorVariable` |
||
207 | The 3(2) dimensional tensor containing output sequences. |
||
208 | The axis 0 must stand for time, the axis 1 for the |
||
209 | position in the batch. |
||
210 | mask : :class:`~tensor.TensorVariable` |
||
211 | The binary matrix identifying fake outputs. |
||
212 | |||
213 | Returns |
||
214 | ------- |
||
215 | cost : :class:`~tensor.Variable` |
||
216 | Theano variable for cost, computed by summing over timesteps |
||
217 | and then averaging over the minibatch. |
||
218 | |||
219 | Notes |
||
220 | ----- |
||
221 | The contexts are expected as keyword arguments. |
||
222 | |||
223 | Adds average cost per sequence element `AUXILIARY` variable to |
||
224 | the computational graph with name ``per_sequence_element``. |
||
225 | |||
226 | """ |
||
227 | # Compute the sum of costs |
||
228 | costs = self.cost_matrix(outputs, mask=mask, **kwargs) |
||
229 | cost = tensor.mean(costs.sum(axis=0)) |
||
230 | add_role(cost, COST) |
||
231 | |||
232 | # Add auxiliary variable for per sequence element cost |
||
233 | application_call.add_auxiliary_variable( |
||
234 | (costs.sum() / mask.sum()) if mask is not None else costs.mean(), |
||
235 | name='per_sequence_element') |
||
236 | return cost |
||
237 | |||
238 | @application |
||
239 | def cost_matrix(self, application_call, outputs, mask=None, **kwargs): |
||
240 | """Returns generation costs for output sequences. |
||
241 | |||
242 | See Also |
||
243 | -------- |
||
244 | :meth:`cost` : Scalar cost. |
||
245 | |||
246 | """ |
||
247 | # We assume the data has axes (time, batch, features, ...) |
||
248 | batch_size = outputs.shape[1] |
||
249 | |||
250 | # Prepare input for the iterative part |
||
251 | states = dict_subset(kwargs, self._state_names, must_have=False) |
||
252 | # masks in context are optional (e.g. `attended_mask`) |
||
253 | contexts = dict_subset(kwargs, self._context_names, must_have=False) |
||
254 | feedback = self.readout.feedback(outputs) |
||
255 | inputs = self.fork.apply(feedback, as_dict=True) |
||
256 | |||
257 | # Run the recurrent network |
||
258 | results = self.transition.apply( |
||
259 | mask=mask, return_initial_states=True, as_dict=True, |
||
260 | **dict_union(inputs, states, contexts)) |
||
261 | |||
262 | # Separate the deliverables. The last states are discarded: they |
||
263 | # are not used to predict any output symbol. The initial glimpses |
||
264 | # are discarded because they are not used for prediction. |
||
265 | # Remember, glimpses are computed _before_ output stage, states are |
||
266 | # computed after. |
||
267 | states = {name: results[name][:-1] for name in self._state_names} |
||
268 | glimpses = {name: results[name][1:] for name in self._glimpse_names} |
||
269 | |||
270 | # Compute the cost |
||
271 | feedback = tensor.roll(feedback, 1, 0) |
||
272 | feedback = tensor.set_subtensor( |
||
273 | feedback[0], |
||
274 | self.readout.feedback(self.readout.initial_outputs(batch_size))) |
||
275 | readouts = self.readout.readout( |
||
276 | feedback=feedback, **dict_union(states, glimpses, contexts)) |
||
277 | costs = self.readout.cost(readouts, outputs) |
||
278 | if mask is not None: |
||
279 | costs *= mask |
||
280 | |||
281 | for name, variable in list(glimpses.items()) + list(states.items()): |
||
282 | application_call.add_auxiliary_variable( |
||
283 | variable.copy(), name=name) |
||
284 | |||
285 | # This variables can be used to initialize the initial states of the |
||
286 | # next batch using the last states of the current batch. |
||
287 | for name in self._state_names + self._glimpse_names: |
||
288 | application_call.add_auxiliary_variable( |
||
289 | results[name][-1].copy(), name=name+"_final_value") |
||
290 | |||
291 | return costs |
||
292 | |||
293 | @recurrent |
||
294 | def generate(self, outputs, **kwargs): |
||
295 | """A sequence generation step. |
||
296 | |||
297 | Parameters |
||
298 | ---------- |
||
299 | outputs : :class:`~tensor.TensorVariable` |
||
300 | The outputs from the previous step. |
||
301 | |||
302 | Notes |
||
303 | ----- |
||
304 | The contexts, previous states and glimpses are expected as keyword |
||
305 | arguments. |
||
306 | |||
307 | """ |
||
308 | states = dict_subset(kwargs, self._state_names) |
||
309 | # masks in context are optional (e.g. `attended_mask`) |
||
310 | contexts = dict_subset(kwargs, self._context_names, must_have=False) |
||
311 | glimpses = dict_subset(kwargs, self._glimpse_names) |
||
312 | |||
313 | next_glimpses = self.transition.take_glimpses( |
||
314 | as_dict=True, **dict_union(states, glimpses, contexts)) |
||
315 | next_readouts = self.readout.readout( |
||
316 | feedback=self.readout.feedback(outputs), |
||
317 | **dict_union(states, next_glimpses, contexts)) |
||
318 | next_outputs = self.readout.emit(next_readouts) |
||
319 | next_costs = self.readout.cost(next_readouts, next_outputs) |
||
320 | next_feedback = self.readout.feedback(next_outputs) |
||
321 | next_inputs = (self.fork.apply(next_feedback, as_dict=True) |
||
322 | if self.fork else {'feedback': next_feedback}) |
||
323 | next_states = self.transition.compute_states( |
||
324 | as_list=True, |
||
325 | **dict_union(next_inputs, states, next_glimpses, contexts)) |
||
326 | return (next_states + [next_outputs] + |
||
327 | list(next_glimpses.values()) + [next_costs]) |
||
328 | |||
329 | @generate.delegate |
||
330 | def generate_delegate(self): |
||
331 | return self.transition.apply |
||
332 | |||
333 | @generate.property('states') |
||
334 | def generate_states(self): |
||
335 | return self._state_names + ['outputs'] + self._glimpse_names |
||
336 | |||
337 | @generate.property('outputs') |
||
338 | def generate_outputs(self): |
||
339 | return (self._state_names + ['outputs'] + |
||
340 | self._glimpse_names + ['costs']) |
||
341 | |||
342 | def get_dim(self, name): |
||
343 | if name in (self._state_names + self._context_names + |
||
344 | self._glimpse_names): |
||
345 | return self.transition.get_dim(name) |
||
346 | elif name == 'outputs': |
||
347 | return self.readout.get_dim(name) |
||
348 | return super(BaseSequenceGenerator, self).get_dim(name) |
||
349 | |||
350 | @application |
||
351 | def initial_states(self, batch_size, *args, **kwargs): |
||
352 | # TODO: support dict of outputs for application methods |
||
353 | # to simplify this code. |
||
354 | state_dict = dict( |
||
355 | self.transition.initial_states( |
||
356 | batch_size, as_dict=True, *args, **kwargs), |
||
357 | outputs=self.readout.initial_outputs(batch_size)) |
||
358 | return [state_dict[state_name] |
||
359 | for state_name in self.generate.states] |
||
360 | |||
361 | @initial_states.property('outputs') |
||
362 | def initial_states_outputs(self): |
||
363 | return self.generate.states |
||
364 | |||
365 | |||
366 | @add_metaclass(ABCMeta) |
||
0 ignored issues
–
show
|
|||
367 | class AbstractReadout(Initializable): |
||
368 | """The interface for the readout component of a sequence generator. |
||
369 | |||
370 | The readout component of a sequence generator is a bridge between |
||
371 | the core recurrent network and the output sequence. |
||
372 | |||
373 | Parameters |
||
374 | ---------- |
||
375 | source_names : list |
||
376 | A list of the source names (outputs) that are needed for the |
||
377 | readout part e.g. ``['states']`` or |
||
378 | ``['states', 'weighted_averages']`` or ``['states', 'feedback']``. |
||
379 | readout_dim : int |
||
380 | The dimension of the readout. |
||
381 | |||
382 | Attributes |
||
383 | ---------- |
||
384 | source_names : list |
||
385 | readout_dim : int |
||
386 | |||
387 | See Also |
||
388 | -------- |
||
389 | :class:`BaseSequenceGenerator` : see how exactly a readout is used |
||
390 | |||
391 | :class:`Readout` : the typically used readout brick |
||
392 | |||
393 | """ |
||
394 | @lazy(allocation=['source_names', 'readout_dim']) |
||
395 | def __init__(self, source_names, readout_dim, **kwargs): |
||
396 | self.source_names = source_names |
||
397 | self.readout_dim = readout_dim |
||
398 | super(AbstractReadout, self).__init__(**kwargs) |
||
399 | |||
400 | @abstractmethod |
||
401 | def emit(self, readouts): |
||
402 | """Produce outputs from readouts. |
||
403 | |||
404 | Parameters |
||
405 | ---------- |
||
406 | readouts : :class:`~theano.Variable` |
||
407 | Readouts produced by the :meth:`readout` method of |
||
408 | a `(batch_size, readout_dim)` shape. |
||
409 | |||
410 | """ |
||
411 | pass |
||
412 | |||
413 | @abstractmethod |
||
414 | def cost(self, readouts, outputs): |
||
415 | """Compute generation cost of outputs given readouts. |
||
416 | |||
417 | Parameters |
||
418 | ---------- |
||
419 | readouts : :class:`~theano.Variable` |
||
420 | Readouts produced by the :meth:`readout` method |
||
421 | of a `(..., readout dim)` shape. |
||
422 | outputs : :class:`~theano.Variable` |
||
423 | Outputs whose cost should be computed. Should have as many |
||
424 | or one less dimensions compared to `readout`. If readout has |
||
425 | `n` dimensions, first `n - 1` dimensions of `outputs` should |
||
426 | match with those of `readouts`. |
||
427 | |||
428 | """ |
||
429 | pass |
||
430 | |||
431 | @abstractmethod |
||
432 | def initial_outputs(self, batch_size): |
||
433 | """Compute initial outputs for the generator's first step. |
||
434 | |||
435 | In the notation from the :class:`BaseSequenceGenerator` |
||
436 | documentation this method should compute :math:`y_0`. |
||
437 | |||
438 | """ |
||
439 | pass |
||
440 | |||
441 | @abstractmethod |
||
442 | def readout(self, **kwargs): |
||
443 | r"""Compute the readout vector from states, glimpses, etc. |
||
444 | |||
445 | Parameters |
||
446 | ---------- |
||
447 | \*\*kwargs: dict |
||
448 | Contains sequence generator states, glimpses, |
||
449 | contexts and feedback from the previous outputs. |
||
450 | |||
451 | """ |
||
452 | pass |
||
453 | |||
454 | @abstractmethod |
||
455 | def feedback(self, outputs): |
||
456 | """Feeds outputs back to be used as inputs of the transition.""" |
||
457 | pass |
||
458 | |||
459 | |||
460 | class Readout(AbstractReadout): |
||
461 | r"""Readout brick with separated emitter and feedback parts. |
||
462 | |||
463 | :class:`Readout` combines a few bits and pieces into an object |
||
464 | that can be used as the readout component in |
||
465 | :class:`BaseSequenceGenerator`. This includes an emitter brick, |
||
466 | to which :meth:`emit`, :meth:`cost` and :meth:`initial_outputs` |
||
467 | calls are delegated, a feedback brick to which :meth:`feedback` |
||
468 | functionality is delegated, and a pipeline to actually compute |
||
469 | readouts from all the sources (see the `source_names` attribute |
||
470 | of :class:`AbstractReadout`). |
||
471 | |||
472 | The readout computation pipeline is constructed from `merge` and |
||
473 | `post_merge` brick, whose responsibilites are described in the |
||
474 | respective docstrings. |
||
475 | |||
476 | Parameters |
||
477 | ---------- |
||
478 | emitter : an instance of :class:`AbstractEmitter` |
||
479 | The emitter component. |
||
480 | feedback_brick : an instance of :class:`AbstractFeedback` |
||
481 | The feedback component. |
||
482 | merge : :class:`~.bricks.Brick`, optional |
||
483 | A brick that takes the sources given in `source_names` as an input |
||
484 | and combines them into a single output. If given, `merge_prototype` |
||
485 | cannot be given. |
||
486 | merge_prototype : :class:`.FeedForward`, optional |
||
487 | If `merge` isn't given, the transformation given by |
||
488 | `merge_prototype` is applied to each input before being summed. By |
||
489 | default a :class:`.Linear` transformation without biases is used. |
||
490 | If given, `merge` cannot be given. |
||
491 | post_merge : :class:`.Feedforward`, optional |
||
492 | This transformation is applied to the merged inputs. By default |
||
493 | :class:`.Bias` is used. |
||
494 | merged_dim : int, optional |
||
495 | The input dimension of `post_merge` i.e. the output dimension of |
||
496 | `merge` (or `merge_prototype`). If not give, it is assumed to be |
||
497 | the same as `readout_dim` (i.e. `post_merge` is assumed to not |
||
498 | change dimensions). |
||
499 | \*\*kwargs : dict |
||
500 | Passed to the parent's constructor. |
||
501 | |||
502 | See Also |
||
503 | -------- |
||
504 | :class:`BaseSequenceGenerator` : see how exactly a readout is used |
||
505 | |||
506 | :class:`AbstractEmitter`, :class:`AbstractFeedback` |
||
507 | |||
508 | """ |
||
509 | def __init__(self, emitter=None, feedback_brick=None, |
||
510 | merge=None, merge_prototype=None, |
||
511 | post_merge=None, merged_dim=None, **kwargs): |
||
512 | |||
513 | if not emitter: |
||
514 | emitter = TrivialEmitter(kwargs['readout_dim']) |
||
515 | if not feedback_brick: |
||
516 | feedback_brick = TrivialFeedback(kwargs['readout_dim']) |
||
517 | if not merge: |
||
518 | merge = Merge(input_names=kwargs['source_names'], |
||
519 | prototype=merge_prototype) |
||
520 | if not post_merge: |
||
521 | post_merge = Bias(dim=kwargs['readout_dim']) |
||
522 | if not merged_dim: |
||
523 | merged_dim = kwargs['readout_dim'] |
||
524 | self.emitter = emitter |
||
525 | self.feedback_brick = feedback_brick |
||
526 | self.merge = merge |
||
527 | self.post_merge = post_merge |
||
528 | self.merged_dim = merged_dim |
||
529 | |||
530 | children = [self.emitter, self.feedback_brick, self.merge, |
||
531 | self.post_merge] |
||
532 | kwargs.setdefault('children', []).extend(children) |
||
533 | super(Readout, self).__init__(**kwargs) |
||
534 | |||
535 | def _push_allocation_config(self): |
||
536 | self.emitter.readout_dim = self.get_dim('readouts') |
||
537 | self.feedback_brick.output_dim = self.get_dim('outputs') |
||
538 | self.merge.input_names = self.source_names |
||
539 | self.merge.input_dims = self.source_dims |
||
540 | self.merge.output_dim = self.merged_dim |
||
541 | self.post_merge.input_dim = self.merged_dim |
||
542 | self.post_merge.output_dim = self.readout_dim |
||
543 | |||
544 | @application |
||
545 | def readout(self, **kwargs): |
||
546 | merged = self.merge.apply(**{name: kwargs[name] |
||
547 | for name in self.merge.input_names}) |
||
548 | merged = self.post_merge.apply(merged) |
||
549 | return merged |
||
550 | |||
551 | @application |
||
552 | def emit(self, readouts): |
||
553 | return self.emitter.emit(readouts) |
||
554 | |||
555 | @application |
||
556 | def cost(self, readouts, outputs): |
||
557 | return self.emitter.cost(readouts, outputs) |
||
558 | |||
559 | @application |
||
560 | def initial_outputs(self, batch_size): |
||
561 | return self.emitter.initial_outputs(batch_size) |
||
562 | |||
563 | @application(outputs=['feedback']) |
||
564 | def feedback(self, outputs): |
||
565 | return self.feedback_brick.feedback(outputs) |
||
566 | |||
567 | def get_dim(self, name): |
||
568 | if name == 'outputs': |
||
569 | return self.emitter.get_dim(name) |
||
570 | elif name == 'feedback': |
||
571 | return self.feedback_brick.get_dim(name) |
||
572 | elif name == 'readouts': |
||
573 | return self.readout_dim |
||
574 | return super(Readout, self).get_dim(name) |
||
575 | |||
576 | |||
577 | @add_metaclass(ABCMeta) |
||
578 | class AbstractEmitter(Brick): |
||
579 | """The interface for the emitter component of a readout. |
||
580 | |||
581 | Attributes |
||
582 | ---------- |
||
583 | readout_dim : int |
||
584 | The dimension of the readout. Is given by the |
||
585 | :class:`Readout` brick when allocation configuration |
||
586 | is pushed. |
||
587 | |||
588 | See Also |
||
589 | -------- |
||
590 | :class:`Readout` |
||
591 | |||
592 | :class:`SoftmaxEmitter` : for integer outputs |
||
593 | |||
594 | Notes |
||
595 | ----- |
||
596 | An important detail about the emitter cost is that it will be |
||
597 | evaluated with inputs of different dimensions so it has to be |
||
598 | flexible enough to handle this. The two ways in which it can be |
||
599 | applied are: |
||
600 | |||
601 | 1. In :meth:BaseSequenceGenerator.cost_matrix where it will |
||
602 | be applied to the whole sequence at once. |
||
603 | |||
604 | 2. In :meth:BaseSequenceGenerator.generate where it will be |
||
605 | applied to only one step of the sequence. |
||
606 | |||
607 | """ |
||
608 | @abstractmethod |
||
609 | def emit(self, readouts): |
||
610 | """Implements the respective method of :class:`Readout`.""" |
||
611 | pass |
||
612 | |||
613 | @abstractmethod |
||
614 | def cost(self, readouts, outputs): |
||
615 | """Implements the respective method of :class:`Readout`.""" |
||
616 | pass |
||
617 | |||
618 | @abstractmethod |
||
619 | def initial_outputs(self, batch_size): |
||
620 | """Implements the respective method of :class:`Readout`.""" |
||
621 | pass |
||
622 | |||
623 | |||
624 | @add_metaclass(ABCMeta) |
||
625 | class AbstractFeedback(Brick): |
||
626 | """The interface for the feedback component of a readout. |
||
627 | |||
628 | See Also |
||
629 | -------- |
||
630 | :class:`Readout` |
||
631 | |||
632 | :class:`LookupFeedback` for integer outputs |
||
633 | |||
634 | """ |
||
635 | @abstractmethod |
||
636 | def feedback(self, outputs): |
||
637 | """Implements the respective method of :class:`Readout`.""" |
||
638 | pass |
||
639 | |||
640 | |||
641 | class TrivialEmitter(AbstractEmitter): |
||
642 | """An emitter for the trivial case when readouts are outputs. |
||
643 | |||
644 | Parameters |
||
645 | ---------- |
||
646 | readout_dim : int |
||
647 | The dimension of the readout. |
||
648 | |||
649 | Notes |
||
650 | ----- |
||
651 | By default :meth:`cost` always returns zero tensor. |
||
652 | |||
653 | """ |
||
654 | @lazy(allocation=['readout_dim']) |
||
655 | def __init__(self, readout_dim, **kwargs): |
||
656 | super(TrivialEmitter, self).__init__(**kwargs) |
||
657 | self.readout_dim = readout_dim |
||
658 | |||
659 | @application |
||
660 | def emit(self, readouts): |
||
661 | return readouts |
||
662 | |||
663 | @application |
||
664 | def cost(self, readouts, outputs): |
||
665 | return tensor.zeros_like(outputs) |
||
666 | |||
667 | @application |
||
668 | def initial_outputs(self, batch_size): |
||
669 | return tensor.zeros((batch_size, self.readout_dim)) |
||
670 | |||
671 | def get_dim(self, name): |
||
672 | if name == 'outputs': |
||
673 | return self.readout_dim |
||
674 | return super(TrivialEmitter, self).get_dim(name) |
||
675 | |||
676 | |||
677 | class SoftmaxEmitter(AbstractEmitter, Initializable, Random): |
||
678 | """A softmax emitter for the case of integer outputs. |
||
679 | |||
680 | Interprets readout elements as energies corresponding to their indices. |
||
681 | |||
682 | Parameters |
||
683 | ---------- |
||
684 | initial_output : int or a scalar :class:`~theano.Variable` |
||
685 | The initial output. |
||
686 | |||
687 | """ |
||
688 | def __init__(self, initial_output=0, **kwargs): |
||
689 | self.initial_output = initial_output |
||
690 | self.softmax = NDimensionalSoftmax() |
||
691 | children = [self.softmax] |
||
692 | kwargs.setdefault('children', []).extend(children) |
||
693 | super(SoftmaxEmitter, self).__init__(**kwargs) |
||
694 | |||
695 | @application |
||
696 | def probs(self, readouts): |
||
697 | return self.softmax.apply(readouts, extra_ndim=readouts.ndim - 2) |
||
698 | |||
699 | @application |
||
700 | def emit(self, readouts): |
||
701 | probs = self.probs(readouts) |
||
702 | batch_size = probs.shape[0] |
||
703 | pvals_flat = probs.reshape((batch_size, -1)) |
||
704 | generated = self.theano_rng.multinomial(pvals=pvals_flat) |
||
705 | return generated.reshape(probs.shape).argmax(axis=-1) |
||
706 | |||
707 | @application |
||
708 | def cost(self, readouts, outputs): |
||
709 | # WARNING: unfortunately this application method works |
||
710 | # just fine when `readouts` and `outputs` have |
||
711 | # different dimensions. Be careful! |
||
712 | return self.softmax.categorical_cross_entropy( |
||
713 | outputs, readouts, extra_ndim=readouts.ndim - 2) |
||
714 | |||
715 | @application |
||
716 | def initial_outputs(self, batch_size): |
||
717 | return self.initial_output * tensor.ones((batch_size,), dtype='int64') |
||
718 | |||
719 | def get_dim(self, name): |
||
720 | if name == 'outputs': |
||
721 | return 0 |
||
722 | return super(SoftmaxEmitter, self).get_dim(name) |
||
723 | |||
724 | |||
725 | class TrivialFeedback(AbstractFeedback): |
||
726 | """A feedback brick for the case when readout are outputs.""" |
||
727 | @lazy(allocation=['output_dim']) |
||
728 | def __init__(self, output_dim, **kwargs): |
||
729 | super(TrivialFeedback, self).__init__(**kwargs) |
||
730 | self.output_dim = output_dim |
||
731 | |||
732 | @application(outputs=['feedback']) |
||
733 | def feedback(self, outputs): |
||
734 | return outputs |
||
735 | |||
736 | def get_dim(self, name): |
||
737 | if name == 'feedback': |
||
738 | return self.output_dim |
||
739 | return super(TrivialFeedback, self).get_dim(name) |
||
740 | |||
741 | |||
742 | class LookupFeedback(AbstractFeedback, Initializable): |
||
743 | """A feedback brick for the case when readout are integers. |
||
744 | |||
745 | Stores and retrieves distributed representations of integers. |
||
746 | |||
747 | """ |
||
748 | def __init__(self, num_outputs=None, feedback_dim=None, **kwargs): |
||
749 | self.num_outputs = num_outputs |
||
750 | self.feedback_dim = feedback_dim |
||
751 | |||
752 | self.lookup = LookupTable(num_outputs, feedback_dim) |
||
753 | children = [self.lookup] |
||
754 | kwargs.setdefault('children', []).extend(children) |
||
755 | super(LookupFeedback, self).__init__(**kwargs) |
||
756 | |||
757 | def _push_allocation_config(self): |
||
758 | self.lookup.length = self.num_outputs |
||
759 | self.lookup.dim = self.feedback_dim |
||
760 | |||
761 | @application |
||
762 | def feedback(self, outputs): |
||
763 | assert self.output_dim == 0 |
||
764 | return self.lookup.apply(outputs) |
||
765 | |||
766 | def get_dim(self, name): |
||
767 | if name == 'feedback': |
||
768 | return self.feedback_dim |
||
769 | return super(LookupFeedback, self).get_dim(name) |
||
770 | |||
771 | |||
772 | class FakeAttentionRecurrent(AbstractAttentionRecurrent, Initializable): |
||
773 | """Adds fake attention interface to a transition. |
||
774 | |||
775 | :class:`BaseSequenceGenerator` requires its transition brick to support |
||
776 | :class:`~blocks.bricks.attention.AbstractAttentionRecurrent` interface, |
||
777 | that is to have an embedded attention mechanism. For the cases when no |
||
778 | attention is required (e.g. language modeling or encoder-decoder |
||
779 | models), :class:`FakeAttentionRecurrent` is used to wrap a usual |
||
780 | recurrent brick. The resulting brick has no glimpses and simply |
||
781 | passes all states and contexts to the wrapped one. |
||
782 | |||
783 | .. todo:: |
||
784 | |||
785 | Get rid of this brick and support attention-less transitions |
||
786 | in :class:`BaseSequenceGenerator`. |
||
787 | |||
788 | """ |
||
789 | def __init__(self, transition, **kwargs): |
||
790 | self.transition = transition |
||
791 | |||
792 | self.state_names = transition.apply.states |
||
793 | self.context_names = transition.apply.contexts |
||
794 | self.glimpse_names = [] |
||
795 | |||
796 | children = [self.transition] |
||
797 | kwargs.setdefault('children', []).extend(children) |
||
798 | super(FakeAttentionRecurrent, self).__init__(**kwargs) |
||
799 | |||
800 | @application |
||
801 | def apply(self, *args, **kwargs): |
||
802 | return self.transition.apply(*args, **kwargs) |
||
803 | |||
804 | @apply.delegate |
||
805 | def apply_delegate(self): |
||
806 | return self.transition.apply |
||
807 | |||
808 | @application |
||
809 | def compute_states(self, *args, **kwargs): |
||
810 | return self.transition.apply(iterate=False, *args, **kwargs) |
||
811 | |||
812 | @compute_states.delegate |
||
813 | def compute_states_delegate(self): |
||
814 | return self.transition.apply |
||
815 | |||
816 | @application(outputs=[]) |
||
817 | def take_glimpses(self, *args, **kwargs): |
||
818 | return None |
||
819 | |||
820 | @application |
||
821 | def initial_states(self, batch_size, *args, **kwargs): |
||
822 | return self.transition.initial_states(batch_size, |
||
823 | *args, **kwargs) |
||
824 | |||
825 | @initial_states.property('outputs') |
||
826 | def initial_states_outputs(self): |
||
827 | return self.transition.apply.states |
||
828 | |||
829 | def get_dim(self, name): |
||
830 | return self.transition.get_dim(name) |
||
831 | |||
832 | |||
833 | class SequenceGenerator(BaseSequenceGenerator): |
||
834 | r"""A more user-friendly interface for :class:`BaseSequenceGenerator`. |
||
835 | |||
836 | Parameters |
||
837 | ---------- |
||
838 | readout : instance of :class:`AbstractReadout` |
||
839 | The readout component for the sequence generator. |
||
840 | transition : instance of :class:`.BaseRecurrent` |
||
841 | The recurrent transition to be used in the sequence generator. |
||
842 | Will be combined with `attention`, if that one is given. |
||
843 | attention : object, optional |
||
844 | The attention mechanism to be added to ``transition``, |
||
845 | an instance of |
||
846 | :class:`~blocks.bricks.attention.AbstractAttention`. |
||
847 | add_contexts : bool |
||
848 | If ``True``, the |
||
849 | :class:`.AttentionRecurrent` wrapping the |
||
850 | `transition` will add additional contexts for the attended and its |
||
851 | mask. |
||
852 | \*\*kwargs : dict |
||
853 | All keywords arguments are passed to the base class. If `fork` |
||
854 | keyword argument is not provided, :class:`.Fork` is created |
||
855 | that forks all transition sequential inputs without a "mask" |
||
856 | substring in them. |
||
857 | |||
858 | """ |
||
859 | def __init__(self, readout, transition, attention=None, |
||
860 | add_contexts=True, **kwargs): |
||
861 | normal_inputs = [name for name in transition.apply.sequences |
||
862 | if 'mask' not in name] |
||
863 | kwargs.setdefault('fork', Fork(normal_inputs)) |
||
864 | if attention: |
||
865 | transition = AttentionRecurrent( |
||
866 | transition, attention, |
||
867 | add_contexts=add_contexts, name="att_trans") |
||
868 | else: |
||
869 | transition = FakeAttentionRecurrent(transition, |
||
870 | name="with_fake_attention") |
||
871 | super(SequenceGenerator, self).__init__( |
||
872 | readout, transition, **kwargs) |
||
873 |
Abstract classes which are used only once can usually be inlined into the class which already uses this abstract class.