Issues (119)

blocks/search.py (3 issues)

1
"""The beam search module."""
2
from collections import OrderedDict
3
from six.moves import range
4
5
import numpy
6
from picklable_itertools.extras import equizip
7
from theano import config, function, tensor
8
9
from blocks.bricks.sequence_generators import BaseSequenceGenerator
10
from blocks.filter import VariableFilter, get_application_call, get_brick
11
from blocks.graph import ComputationGraph
12
from blocks.roles import INPUT, OUTPUT
13
from blocks.utils import unpack
14
15
16
class BeamSearch(object):
    """Approximate search for the most likely sequence.

    Beam search is an approximate algorithm for finding :math:`y^* =
    argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
    the contexts, :math:`P` is the output distribution of a
    :class:`.SequenceGenerator`. At each step it considers :math:`k`
    candidate sequence prefixes. :math:`k` is called the beam size, and the
    sequences are called the beam. The sequences are replaced with their
    :math:`k` most probable continuations, and this is repeated until
    the end-of-sequence symbol is met.

    The beam search compiles quite a few Theano functions under the hood.
    Normally those are compiled at the first :meth:`search` call, but
    you can also explicitly call :meth:`compile`.

    Parameters
    ----------
    samples : :class:`~theano.Variable`
        An output of a sampling computation graph built by
        :meth:`~blocks.brick.SequenceGenerator.generate`, the one
        corresponding to sampled sequences.

    See Also
    --------
    :class:`.SequenceGenerator`

    Notes
    -----
    Sequence generator should use an emitter which has `probs` method
    e.g. :class:`SoftmaxEmitter`.

    Does not support dummy contexts so far (all the contexts must be used
    in the `generate` method of the sequence generator for the current code
    to work).

    """
    def __init__(self, samples):
        # Extracting information from the sampling computation graph
        self.cg = ComputationGraph(samples)
        self.inputs = self.cg.inputs
        self.generator = get_brick(samples)
        if not isinstance(self.generator, BaseSequenceGenerator):
            raise ValueError(
                "samples should be produced by a BaseSequenceGenerator, "
                "but the owning brick is of type {}".format(
                    type(self.generator)))
        self.generate_call = get_application_call(samples)
        if (not self.generate_call.application ==
                self.generator.generate):
            raise ValueError(
                "samples should be an output of the `generate` application "
                "of the sequence generator")
        self.inner_cg = ComputationGraph(self.generate_call.inner_outputs)

        # Fetching names from the sequence generator
        self.context_names = self.generator.generate.contexts
        self.state_names = self.generator.generate.states

        # Parsing the inner computation graph of sampling scan
        self.contexts = [
            VariableFilter(bricks=[self.generator],
                           name=name,
                           roles=[INPUT])(self.inner_cg)[0]
            for name in self.context_names]
        self.input_states = []
        # Includes only those state names that were actually used
        # in 'generate'
        self.input_state_names = []
        for name in self.generator.generate.states:
            var = VariableFilter(
                bricks=[self.generator], name=name,
                roles=[INPUT])(self.inner_cg)
            if var:
                self.input_state_names.append(name)
                self.input_states.append(var[0])

        # Theano functions are compiled lazily at the first `search` call.
        self.compiled = False

    def _compile_initial_state_and_context_computer(self):
        """Compile the function mapping inputs to initial states/contexts.

        The compiled function also returns `beam_size`, recovered from the
        `batch_size` argument of the generator's `initial_states`
        application.

        """
        initial_states = VariableFilter(
                            applications=[self.generator.initial_states],
                            roles=[OUTPUT])(self.cg)
        outputs = OrderedDict([(v.tag.name, v) for v in initial_states])
        beam_size = unpack(VariableFilter(
                            applications=[self.generator.initial_states],
                            name='batch_size')(self.cg))
        for name, context in equizip(self.context_names, self.contexts):
            outputs[name] = context
        outputs['beam_size'] = beam_size
        self.initial_state_and_context_computer = function(
            self.inputs, outputs, on_unused_input='ignore')

    def _compile_next_state_computer(self):
        """Compile the function computing next states from current ones."""
        next_states = [VariableFilter(bricks=[self.generator],
                                      name=name,
                                      roles=[OUTPUT])(self.inner_cg)[-1]
                       for name in self.state_names]
        next_outputs = VariableFilter(
            applications=[self.generator.readout.emit], roles=[OUTPUT])(
                self.inner_cg.variables)
        self.next_state_computer = function(
            self.contexts + self.input_states + next_outputs, next_states,
            on_unused_input='ignore')

    def _compile_logprobs_computer(self):
        """Compile the function computing output log-probabilities."""
        # This filtering should return identical variables
        # (in terms of computations) variables, and we do not care
        # which to use.
        probs = VariableFilter(
            applications=[self.generator.readout.emitter.probs],
            roles=[OUTPUT])(self.inner_cg)[0]
        logprobs = -tensor.log(probs)
        self.logprobs_computer = function(
            self.contexts + self.input_states, logprobs,
            on_unused_input='ignore')

    def compile(self):
        """Compile all Theano functions used."""
        self._compile_initial_state_and_context_computer()
        self._compile_next_state_computer()
        self._compile_logprobs_computer()
        self.compiled = True

    def compute_initial_states_and_contexts(self, inputs):
        """Computes initial states and contexts from inputs.

        Parameters
        ----------
        inputs : dict
            Dictionary of input arrays.

        Returns
        -------
        A tuple containing a {name: :class:`numpy.ndarray`} dictionary of
        contexts ordered like `self.context_names` and a
        {name: :class:`numpy.ndarray`} dictionary of states ordered like
        `self.state_names`.

        """
        outputs = self.initial_state_and_context_computer(
            *[inputs[var] for var in self.inputs])
        contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names)
        beam_size = outputs.pop('beam_size')
        # Whatever remains in `outputs` is the initial states.
        initial_states = outputs
        return contexts, initial_states, beam_size

    def compute_logprobs(self, contexts, states):
        """Compute log probabilities of all possible outputs.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.

        Returns
        -------
        A :class:`numpy.ndarray` of the (beam size, number of possible
        outputs) shape.

        """
        input_states = [states[name] for name in self.input_state_names]
        return self.logprobs_computer(*(list(contexts.values()) +
                                      input_states))

    def compute_next_states(self, contexts, states, outputs):
        """Computes next states.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.
        outputs : :class:`numpy.ndarray`
            A :class:`numpy.ndarray` of this step outputs.

        Returns
        -------
        A {name: numpy.array} dictionary of next states.

        """
        input_states = [states[name] for name in self.input_state_names]
        next_values = self.next_state_computer(*(list(contexts.values()) +
                                                 input_states + [outputs]))
        return OrderedDict(equizip(self.state_names, next_values))

    @staticmethod
    def _smallest(matrix, k, only_first_row=False):
        """Find k smallest elements of a matrix.

        Parameters
        ----------
        matrix : :class:`numpy.ndarray`
            The matrix.
        k : int
            The number of smallest elements required.
        only_first_row : bool, optional
            Consider only elements of the first row.

        Returns
        -------
        Tuple of ((row numbers, column numbers), values).

        """
        if only_first_row:
            flatten = matrix[:1, :].flatten()
        else:
            flatten = matrix.flatten()
        # `argpartition` requires its pivot index to be strictly within
        # bounds; when `k` is not smaller than the number of candidates
        # (e.g. beam size >= number of possible outputs), fall back to
        # taking all of them.
        if k >= flatten.shape[0]:
            args = numpy.arange(flatten.shape[0])
        else:
            args = numpy.argpartition(flatten, k)[:k]
        # Order the selected candidates by increasing value.
        args = args[numpy.argsort(flatten[args])]
        return numpy.unravel_index(args, matrix.shape), flatten[args]

    def search(self, input_values, eol_symbol, max_length,
               ignore_first_eol=False, as_arrays=False):
        """Performs beam search.

        If the beam search was not compiled, it also compiles it.

        Parameters
        ----------
        input_values : dict
            A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
            dictionary of input values. The shapes should be
            the same as if you ran sampling with batch size equal to
            `beam_size`. Put it differently, the user is responsible
            for duplicating inputs necessary number of times, because
            this class has insufficient information to do it properly.
        eol_symbol : int
            End of sequence symbol, the search stops when the symbol is
            generated.
        max_length : int
            Maximum sequence length, the search stops when it is reached.
        ignore_first_eol : bool, optional
            When ``True``, the end-of-sequence symbols generated at the
            first iteration are ignored. This is useful when the sequence
            generator was trained on data with identical symbols for
            sequence start and sequence end.
        as_arrays : bool, optional
            If ``True``, the internal representation of search results
            is returned, that is a (matrix of outputs, mask,
            costs of all generated outputs) tuple.

        Returns
        -------
        outputs : list of lists of ints
            A list of the `beam_size` best sequences found in the order
            of decreasing likelihood.
        costs : list of floats
            A list of the costs for the `outputs`, where cost is the
            negative log-likelihood.

        """
        if not self.compiled:
            self.compile()

        contexts, states, beam_size = self.compute_initial_states_and_contexts(
            input_values)

        # This array will store all generated outputs, including those from
        # previous step and those from already finished sequences.
        all_outputs = states['outputs'][None, :]
        all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
        all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)

        for i in range(max_length):
            if all_masks[-1].sum() == 0:
                break

            # We carefully hack values of the `logprobs` array to ensure
            # that all finished sequences are continued with `eol_symbol`.
            logprobs = self.compute_logprobs(contexts, states)
            next_costs = (all_costs[-1, :, None] +
                          logprobs * all_masks[-1, :, None])
            (finished,) = numpy.where(all_masks[-1] == 0)
            next_costs[finished, :eol_symbol] = numpy.inf
            next_costs[finished, eol_symbol + 1:] = numpy.inf

            # The `i == 0` is required because at the first step the beam
            # size is effectively only 1.
            (indexes, outputs), chosen_costs = self._smallest(
                next_costs, beam_size, only_first_row=i == 0)

            # Rearrange everything
            for name in states:
                states[name] = states[name][indexes]
            all_outputs = all_outputs[:, indexes]
            all_masks = all_masks[:, indexes]
            all_costs = all_costs[:, indexes]

            # Record chosen output and compute new states
            states.update(self.compute_next_states(contexts, states, outputs))
            all_outputs = numpy.vstack([all_outputs, outputs[None, :]])
            all_costs = numpy.vstack([all_costs, chosen_costs[None, :]])
            mask = outputs != eol_symbol
            if ignore_first_eol and i == 0:
                mask[:] = 1
            all_masks = numpy.vstack([all_masks, mask[None, :]])

        # Drop the dummy initial output row; the mask row for the last
        # output is not needed. Convert cumulative costs into per-step
        # costs by differencing consecutive rows.
        all_outputs = all_outputs[1:]
        all_masks = all_masks[:-1]
        all_costs = all_costs[1:] - all_costs[:-1]
        result = all_outputs, all_masks, all_costs
        if as_arrays:
            return result
        return self.result_to_lists(result)

    @staticmethod
    def result_to_lists(result):
        """Convert an (outputs, masks, costs) array triple to lists.

        Each output sequence is truncated to its masked length; the cost
        of a sequence is the sum of its per-step costs.

        """
        outputs, masks, costs = [array.T for array in result]
        outputs = [list(output[:int(mask.sum())])
                   for output, mask in equizip(outputs, masks)]
        costs = list(costs.T.sum(axis=0))
        return outputs, costs