Completed
Push — master ( 8efd18...1c90f6 )
by Dmitry
07:46 queued 04:23
created

_compile_initial_state_and_context_computer()   A

Complexity

Conditions 3

Size

Total Lines 13

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 3
dl 0
loc 13
rs 9.4285
1
"""The beam search module."""
2
from collections import OrderedDict
3
from six.moves import range
0 ignored issues
show
Bug Best Practice introduced by
This seems to re-define the built-in range.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
4
5
import numpy
6
from picklable_itertools.extras import equizip
7
from theano import config, function, tensor
8
9
from blocks.bricks.sequence_generators import BaseSequenceGenerator
10
from blocks.filter import VariableFilter, get_application_call, get_brick
11
from blocks.graph import ComputationGraph
12
from blocks.roles import INPUT, OUTPUT
13
from blocks.utils import unpack
14
15
16
class BeamSearch(object):
17
    """Approximate search for the most likely sequence.
18
19
    Beam search is an approximate algorithm for finding :math:`y^* =
20
    argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
21
    the contexts, :math:`P` is the output distribution of a
22
    :class:`.SequenceGenerator`. At each step it considers :math:`k`
23
    candidate sequence prefixes. :math:`k` is called the beam size, and the
24
    sequences are called the beam. The sequences are replaced with their
25
    :math:`k` most probable continuations, and this is repeated until
26
    an end-of-sequence symbol is met.
27
28
    The beam search compiles quite a few Theano functions under the hood.
29
    Normally those are compiled at the first :meth:`search` call, but
30
    you can also explicitly call :meth:`compile`.
31
32
    Parameters
33
    ----------
34
    samples : :class:`~theano.Variable`
35
        An output of a sampling computation graph built by
36
        :meth:`~blocks.bricks.SequenceGenerator.generate`, the one
37
        corresponding to sampled sequences.
38
39
    See Also
40
    --------
41
    :class:`.SequenceGenerator`
42
43
    Notes
44
    -----
45
    Sequence generator should use an emitter which has `probs` method
46
    e.g. :class:`SoftmaxEmitter`.
47
48
    Does not support dummy contexts so far (all the contexts must be used
49
    in the `generate` method of the sequence generator for the current code
50
    to work).
51
52
    """
53
    def __init__(self, samples):
        """Extract the information needed for search from a sampling graph.

        Parameters
        ----------
        samples : :class:`~theano.Variable`
            An output of a sampling computation graph built by
            :meth:`~blocks.bricks.SequenceGenerator.generate`, the one
            corresponding to sampled sequences.

        Raises
        ------
        ValueError
            If `samples` is not owned by a
            :class:`.BaseSequenceGenerator` or is not an output of its
            `generate` application.

        """
        # Extracting information from the sampling computation graph
        self.cg = ComputationGraph(samples)
        self.inputs = self.cg.inputs
        self.generator = get_brick(samples)
        if not isinstance(self.generator, BaseSequenceGenerator):
            # Descriptive message instead of a bare ValueError
            raise ValueError(
                "samples must be produced by a BaseSequenceGenerator, "
                "got a variable owned by {}".format(self.generator))
        self.generate_call = get_application_call(samples)
        if (not self.generate_call.application ==
                self.generator.generate):
            raise ValueError(
                "samples must be an output of the `generate` application "
                "of the sequence generator")
        self.inner_cg = ComputationGraph(self.generate_call.inner_outputs)

        # Fetching names from the sequence generator
        self.context_names = self.generator.generate.contexts
        self.state_names = self.generator.generate.states

        # Parsing the inner computation graph of sampling scan
        self.contexts = [
            VariableFilter(bricks=[self.generator],
                           name=name,
                           roles=[INPUT])(self.inner_cg)[0]
            for name in self.context_names]
        self.input_states = []
        # Includes only those state names that were actually used
        # in 'generate'
        self.input_state_names = []
        for name in self.generator.generate.states:
            var = VariableFilter(
                bricks=[self.generator], name=name,
                roles=[INPUT])(self.inner_cg)
            if var:
                self.input_state_names.append(name)
                self.input_states.append(var[0])

        # Theano functions are compiled lazily; see `compile`.
        self.compiled = False
89
90
    def _compile_initial_state_and_context_computer(self):
        """Compile the Theano function producing initial states and contexts.

        The compiled function maps the sampling graph inputs to an
        ordered dictionary containing the initial states, the contexts
        and the beam size (taken from the `batch_size` of the
        `initial_states` application).
        """
        state_vars = VariableFilter(
            applications=[self.generator.initial_states],
            roles=[OUTPUT])(self.cg)
        computed = OrderedDict((var.tag.name, var) for var in state_vars)
        batch_size_var = unpack(VariableFilter(
            applications=[self.generator.initial_states],
            name='batch_size')(self.cg))
        computed.update(equizip(self.context_names, self.contexts))
        computed['beam_size'] = batch_size_var
        self.initial_state_and_context_computer = function(
            self.inputs, computed, on_unused_input='ignore')
103
104
    def _compile_next_state_computer(self):
        """Compile the Theano function advancing the generator states."""
        def updated_state(name):
            # The last OUTPUT occurrence in the inner graph corresponds
            # to the state value after the step.
            return VariableFilter(bricks=[self.generator],
                                  name=name,
                                  roles=[OUTPUT])(self.inner_cg)[-1]
        updated_states = [updated_state(name) for name in self.state_names]
        emitted = VariableFilter(
            applications=[self.generator.readout.emit], roles=[OUTPUT])(
                self.inner_cg.variables)
        self.next_state_computer = function(
            self.contexts + self.input_states + emitted, updated_states)
114
115
    def _compile_logprobs_computer(self):
        """Compile the Theano function returning -log(P) of all outputs."""
        # The filtering returns variables that are identical in terms of
        # the computation they perform; any one of them can be used.
        probs_var = VariableFilter(
            applications=[self.generator.readout.emitter.probs],
            roles=[OUTPUT])(self.inner_cg)[0]
        self.logprobs_computer = function(
            self.contexts + self.input_states, -tensor.log(probs_var),
            on_unused_input='ignore')
126
127
    def compile(self):
128
        """Compile all Theano functions used."""
129
        self._compile_initial_state_and_context_computer()
130
        self._compile_next_state_computer()
131
        self._compile_logprobs_computer()
132
        self.compiled = True
133
134
    def compute_initial_states_and_contexts(self, inputs):
        """Computes initial states and contexts from inputs.

        Parameters
        ----------
        inputs : dict
            Dictionary of input arrays, keyed by the input variables of
            the sampling computation graph.

        Returns
        -------
        contexts : :class:`~collections.OrderedDict`
            A {name: :class:`numpy.ndarray`} dictionary of contexts
            ordered like `self.context_names`.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of initial
            states.
        beam_size : :class:`numpy.ndarray`
            The beam size, i.e. the `batch_size` used when the sampling
            graph was built.

        """
        outputs = self.initial_state_and_context_computer(
            *[inputs[var] for var in self.inputs])
        contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names)
        beam_size = outputs.pop('beam_size')
        # After popping the contexts and the beam size, whatever remains
        # in `outputs` is exactly the initial states.
        initial_states = outputs
        return contexts, initial_states, beam_size
156
157
    def compute_logprobs(self, contexts, states):
158
        """Compute log probabilities of all possible outputs.
159
160
        Parameters
161
        ----------
162
        contexts : dict
163
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
164
        states : dict
165
            A {name: :class:`numpy.ndarray`} dictionary of states.
166
167
        Returns
168
        -------
169
        A :class:`numpy.ndarray` of the (beam size, number of possible
170
        outputs) shape.
171
172
        """
173
        input_states = [states[name] for name in self.input_state_names]
174
        return self.logprobs_computer(*(list(contexts.values()) +
175
                                      input_states))
176
177
    def compute_next_states(self, contexts, states, outputs):
        """Computes next states.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.
        outputs : :class:`numpy.ndarray`
            A :class:`numpy.ndarray` of this step outputs.

        Returns
        -------
        A {name: numpy.array} dictionary of next states.

        """
        # Assemble the arguments in the order the compiled function
        # expects: contexts, used states, then the chosen outputs.
        arguments = list(contexts.values())
        arguments += [states[name] for name in self.input_state_names]
        arguments.append(outputs)
        next_values = self.next_state_computer(*arguments)
        return OrderedDict(equizip(self.state_names, next_values))
198
199
    @staticmethod
200
    def _smallest(matrix, k, only_first_row=False):
201
        """Find k smallest elements of a matrix.
202
203
        Parameters
204
        ----------
205
        matrix : :class:`numpy.ndarray`
206
            The matrix.
207
        k : int
208
            The number of smallest elements required.
209
        only_first_row : bool, optional
210
            Consider only elements of the first row.
211
212
        Returns
213
        -------
214
        Tuple of ((row numbers, column numbers), values).
215
216
        """
217
        if only_first_row:
218
            flatten = matrix[:1, :].flatten()
219
        else:
220
            flatten = matrix.flatten()
221
        args = numpy.argpartition(flatten, k)[:k]
222
        args = args[numpy.argsort(flatten[args])]
223
        return numpy.unravel_index(args, matrix.shape), flatten[args]
224
225
    def search(self, input_values, eol_symbol, max_length,
               ignore_first_eol=False, as_arrays=False):
        """Performs beam search.

        If the beam search was not compiled, it also compiles it.

        Parameters
        ----------
        input_values : dict
            A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
            dictionary of input values. The shapes should be
            the same as if you ran sampling with batch size equal to
            `beam_size`. Put it differently, the user is responsible
            for duplicating inputs the necessary number of times, because
            this class has insufficient information to do it properly.
        eol_symbol : int
            End of sequence symbol, the search stops when the symbol is
            generated.
        max_length : int
            Maximum sequence length, the search stops when it is reached.
        ignore_first_eol : bool, optional
            When ``True``, the end-of-sequence symbols generated at the
            first iteration are ignored. This is useful when the sequence
            generator was trained on data with identical symbols for
            sequence start and sequence end.
        as_arrays : bool, optional
            If ``True``, the internal representation of search results
            is returned, that is a (matrix of outputs, mask,
            costs of all generated outputs) tuple.

        Returns
        -------
        outputs : list of lists of ints
            A list of the `beam_size` best sequences found in the order
            of decreasing likelihood.
        costs : list of floats
            A list of the costs for the `outputs`, where cost is the
            negative log-likelihood.

        """
        if not self.compiled:
            self.compile()

        contexts, states, beam_size = self.compute_initial_states_and_contexts(
            input_values)

        # This array will store all generated outputs, including those from
        # previous step and those from already finished sequences.
        all_outputs = states['outputs'][None, :]
        all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
        all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)

        for i in range(max_length):
            if all_masks[-1].sum() == 0:
                # Every sequence in the beam has already finished.
                break

            # We carefully hack values of the `logprobs` array to ensure
            # that all finished sequences are continued with `eol_symbol`:
            # every other continuation is made infinitely costly.
            logprobs = self.compute_logprobs(contexts, states)
            next_costs = (all_costs[-1, :, None] +
                          logprobs * all_masks[-1, :, None])
            (finished,) = numpy.where(all_masks[-1] == 0)
            next_costs[finished, :eol_symbol] = numpy.inf
            next_costs[finished, eol_symbol + 1:] = numpy.inf

            # The `i == 0` is required because at the first step the beam
            # size is effectively only 1.
            (indexes, outputs), chosen_costs = self._smallest(
                next_costs, beam_size, only_first_row=i == 0)

            # Rearrange everything so that row j everywhere corresponds
            # to the j-th chosen continuation.
            for name in states:
                states[name] = states[name][indexes]
            all_outputs = all_outputs[:, indexes]
            all_masks = all_masks[:, indexes]
            all_costs = all_costs[:, indexes]

            # Record chosen output and compute new states
            states.update(self.compute_next_states(contexts, states, outputs))
            all_outputs = numpy.vstack([all_outputs, outputs[None, :]])
            all_costs = numpy.vstack([all_costs, chosen_costs[None, :]])
            mask = outputs != eol_symbol
            if ignore_first_eol and i == 0:
                mask[:] = 1
            all_masks = numpy.vstack([all_masks, mask[None, :]])

        # Drop the first outputs row (it holds the initial `outputs` state,
        # not a generated symbol) and the last mask row; turn cumulative
        # costs into per-step costs by differencing.
        all_outputs = all_outputs[1:]
        all_masks = all_masks[:-1]
        all_costs = all_costs[1:] - all_costs[:-1]
        result = all_outputs, all_masks, all_costs
        if as_arrays:
            return result
        return self.result_to_lists(result)
318
319
    @staticmethod
    def result_to_lists(result):
        """Convert an (outputs, masks, costs) array triple to plain lists.

        Each output sequence is truncated to the length given by its
        mask, and the per-step costs are summed into a single total cost
        per sequence.
        """
        outputs_t, masks_t, costs_t = (array.T for array in result)
        sequences = []
        for output, mask in equizip(outputs_t, masks_t):
            length = mask.sum()
            sequences.append(list(output[:length]))
        total_costs = list(costs_t.T.sum(axis=0))
        return sequences, total_costs
326