"""The beam search module."""
from collections import OrderedDict

import numpy
from picklable_itertools.extras import equizip
from six.moves import range
from theano import config, function, tensor

from blocks.bricks.sequence_generators import BaseSequenceGenerator
from blocks.filter import VariableFilter, get_application_call, get_brick
from blocks.graph import ComputationGraph
from blocks.roles import INPUT, OUTPUT
from blocks.utils import unpack


class BeamSearch(object):
    """Approximate search for the most likely sequence.

    Beam search is an approximate algorithm for finding :math:`y^* =
    argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
    the contexts, and :math:`P` is the output distribution of a
    :class:`.SequenceGenerator`. At each step it considers :math:`k`
    candidate sequence prefixes. :math:`k` is called the beam size, and
    the sequences are called the beam. The sequences are replaced with
    their :math:`k` most probable continuations, and this is repeated
    until the end-of-sequence symbol is met.

    The beam search compiles quite a few Theano functions under the hood.
    Normally those are compiled at the first :meth:`search` call, but
    you can also explicitly call :meth:`compile`.

    Parameters
    ----------
    samples : :class:`~theano.Variable`
        An output of a sampling computation graph built by
        :meth:`~blocks.bricks.SequenceGenerator.generate`, the one
        corresponding to sampled sequences.

    See Also
    --------
    :class:`.SequenceGenerator`

    Notes
    -----
    The sequence generator should use an emitter which has a `probs`
    method, e.g. :class:`SoftmaxEmitter`.

    Does not support dummy contexts so far (all the contexts must be
    used in the `generate` method of the sequence generator for the
    current code to work).
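
    Examples
    --------
    A minimal usage sketch. Here ``generator`` stands for a trained
    :class:`.SequenceGenerator`; it, ``beam_size`` and ``input_values``
    are illustrative assumptions, not names defined in this module::

        sampling_graph = ComputationGraph(
            generator.generate(n_steps=10, batch_size=beam_size))
        samples, = VariableFilter(
            applications=[generator.generate], name='outputs')(
                sampling_graph)
        beam_search = BeamSearch(samples)
        # `input_values` maps the inputs of the sampling graph to
        # arrays already repeated `beam_size` times.
        outputs, costs = beam_search.search(
            input_values, eol_symbol=0, max_length=100)
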
    """
    def __init__(self, samples):
        # Extract information from the sampling computation graph
        self.cg = ComputationGraph(samples)
        self.inputs = self.cg.inputs
        self.generator = get_brick(samples)
        if not isinstance(self.generator, BaseSequenceGenerator):
            raise ValueError("samples must be produced by a "
                             "BaseSequenceGenerator")
        self.generate_call = get_application_call(samples)
        if self.generate_call.application != self.generator.generate:
            raise ValueError("samples must be an output of the `generate` "
                             "application of the generator")
        self.inner_cg = ComputationGraph(self.generate_call.inner_outputs)

        # Fetch names from the sequence generator
        self.context_names = self.generator.generate.contexts
        self.state_names = self.generator.generate.states

        # Parse the inner computation graph of the sampling scan
        self.contexts = [
            VariableFilter(bricks=[self.generator],
                           name=name,
                           roles=[INPUT])(self.inner_cg)[0]
            for name in self.context_names]
        self.input_states = []
        # Include only those state names that are actually used
        # in `generate`
        self.input_state_names = []
        for name in self.generator.generate.states:
            var = VariableFilter(
                bricks=[self.generator], name=name,
                roles=[INPUT])(self.inner_cg)
            if var:
                self.input_state_names.append(name)
                self.input_states.append(var[0])

        self.compiled = False

    def _compile_initial_state_and_context_computer(self):
        initial_states = VariableFilter(
            applications=[self.generator.initial_states],
            roles=[OUTPUT])(self.cg)
        outputs = OrderedDict([(v.tag.name, v) for v in initial_states])
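        # The beam size is recovered from the `batch_size` argument of
        # `initial_states`: the sampling graph is assumed to have been
        # built with batch size equal to the beam size.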
        beam_size = unpack(VariableFilter(
            applications=[self.generator.initial_states],
            name='batch_size')(self.cg))
        for name, context in equizip(self.context_names, self.contexts):
            outputs[name] = context
        outputs['beam_size'] = beam_size
        self.initial_state_and_context_computer = function(
            self.inputs, outputs, on_unused_input='ignore')

    def _compile_next_state_computer(self):
        next_states = [VariableFilter(bricks=[self.generator],
                                      name=name,
                                      roles=[OUTPUT])(self.inner_cg)[-1]
                       for name in self.state_names]
        next_outputs = VariableFilter(
            applications=[self.generator.readout.emit], roles=[OUTPUT])(
                self.inner_cg.variables)
        self.next_state_computer = function(
            self.contexts + self.input_states + next_outputs, next_states,
            on_unused_input='ignore')

    def _compile_logprobs_computer(self):
        # This filtering should return variables that are identical
        # in terms of computation, so we do not care which one to use.
        probs = VariableFilter(
            applications=[self.generator.readout.emitter.probs],
            roles=[OUTPUT])(self.inner_cg)[0]
        logprobs = -tensor.log(probs)
        self.logprobs_computer = function(
            self.contexts + self.input_states, logprobs,
            on_unused_input='ignore')

    def compile(self):
        """Compile all Theano functions used."""
        self._compile_initial_state_and_context_computer()
        self._compile_next_state_computer()
        self._compile_logprobs_computer()
        self.compiled = True

    def compute_initial_states_and_contexts(self, inputs):
        """Computes initial states and contexts from inputs.

        Parameters
        ----------
        inputs : dict
            Dictionary of input arrays.

        Returns
        -------
        A tuple containing a {name: :class:`numpy.ndarray`} dictionary
        of contexts ordered like `self.context_names`, a {name:
        :class:`numpy.ndarray`} dictionary of states ordered like
        `self.state_names`, and the beam size.

        """
        outputs = self.initial_state_and_context_computer(
            *[inputs[var] for var in self.inputs])
        contexts = OrderedDict((n, outputs.pop(n))
                               for n in self.context_names)
        beam_size = outputs.pop('beam_size')
        initial_states = outputs
        return contexts, initial_states, beam_size

    def compute_logprobs(self, contexts, states):
        """Compute negative log probabilities of all possible outputs.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.

        Returns
        -------
        A :class:`numpy.ndarray` of shape (beam size, number of possible
        outputs) containing the negative log probabilities.

        """
        input_states = [states[name] for name in self.input_state_names]
        return self.logprobs_computer(*(list(contexts.values()) +
                                        input_states))

    def compute_next_states(self, contexts, states, outputs):
        """Computes next states.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.
        outputs : :class:`numpy.ndarray`
            A :class:`numpy.ndarray` of this step's outputs.

        Returns
        -------
        A {name: :class:`numpy.ndarray`} dictionary of next states.

        """
        input_states = [states[name] for name in self.input_state_names]
        next_values = self.next_state_computer(*(list(contexts.values()) +
                                                 input_states + [outputs]))
        return OrderedDict(equizip(self.state_names, next_values))

    @staticmethod
    def _smallest(matrix, k, only_first_row=False):
        """Find k smallest elements of a matrix.

        Parameters
        ----------
        matrix : :class:`numpy.ndarray`
            The matrix.
        k : int
            The number of smallest elements required.
        only_first_row : bool, optional
            Consider only elements of the first row.

        Returns
        -------
        Tuple of ((row numbers, column numbers), values).
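
        Examples
        --------
        The two smallest values below, 1 and 2, are at positions (1, 0)
        and (0, 1):

        >>> matrix = numpy.array([[3, 2, 4], [1, 5, 9]])
        >>> BeamSearch._smallest(matrix, 2)
        ((array([1, 0]), array([0, 1])), array([1, 2]))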

        """
        if only_first_row:
            flattened = matrix[:1, :].flatten()
        else:
            flattened = matrix.flatten()
        # `argpartition` moves the k smallest elements to the front in
        # O(n); `argsort` then orders only those k elements.
        args = numpy.argpartition(flattened, k)[:k]
        args = args[numpy.argsort(flattened[args])]
        return numpy.unravel_index(args, matrix.shape), flattened[args]

    def search(self, input_values, eol_symbol, max_length,
               ignore_first_eol=False, as_arrays=False):
        """Performs beam search.

        If the beam search was not compiled, it also compiles it.

        Parameters
        ----------
        input_values : dict
            A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
            dictionary of input values. The shapes should be
            the same as if you ran sampling with batch size equal to
            `beam_size`. Put differently, the user is responsible
            for duplicating inputs the necessary number of times, because
            this class has insufficient information to do it properly.
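            For example, a context matrix of shape (length, 1) could be
            repeated along the batch axis with ``numpy.tile(context,
            (1, beam_size))``; the exact shapes depend on your model, so
            treat this as an illustrative recipe.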
        eol_symbol : int
            End of sequence symbol, the search stops when the symbol is
            generated.
        max_length : int
            Maximum sequence length, the search stops when it is reached.
        ignore_first_eol : bool, optional
            When ``True``, the end-of-sequence symbol generated at the
            first iteration is ignored. This is useful when the sequence
            generator was trained on data with identical symbols for
            sequence start and sequence end.
        as_arrays : bool, optional
            If ``True``, the internal representation of the search results
            is returned, that is, a (matrix of outputs, mask,
            costs of all generated outputs) tuple.

        Returns
        -------
        outputs : list of lists of ints
            A list of the `beam_size` best sequences found in the order
            of decreasing likelihood.
        costs : list of floats
            A list of the costs for the `outputs`, where cost is the
            negative log-likelihood.

        """
        if not self.compiled:
            self.compile()

        contexts, states, beam_size = (
            self.compute_initial_states_and_contexts(input_values))

        # This array will store all generated outputs, including those
        # from previous steps and those from already finished sequences.
        all_outputs = states['outputs'][None, :]
        all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
        all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)

        for i in range(max_length):
            if all_masks[-1].sum() == 0:
                break

            # We carefully hack the values of the `next_costs` array to
            # ensure that all finished sequences are continued with the
            # `eol_symbol`.
            logprobs = self.compute_logprobs(contexts, states)
            next_costs = (all_costs[-1, :, None] +
                          logprobs * all_masks[-1, :, None])
            (finished,) = numpy.where(all_masks[-1] == 0)
            next_costs[finished, :eol_symbol] = numpy.inf
            next_costs[finished, eol_symbol + 1:] = numpy.inf

            # The `i == 0` case is required because at the first step the
            # beam size is effectively only 1.
            (indexes, outputs), chosen_costs = self._smallest(
                next_costs, beam_size, only_first_row=i == 0)

            # Rearrange everything
            for name in states:
                states[name] = states[name][indexes]
            all_outputs = all_outputs[:, indexes]
            all_masks = all_masks[:, indexes]
            all_costs = all_costs[:, indexes]

            # Record chosen output and compute new states
            states.update(self.compute_next_states(contexts, states,
                                                   outputs))
            all_outputs = numpy.vstack([all_outputs, outputs[None, :]])
            all_costs = numpy.vstack([all_costs, chosen_costs[None, :]])
            mask = outputs != eol_symbol
            if ignore_first_eol and i == 0:
                mask[:] = 1
            all_masks = numpy.vstack([all_masks, mask[None, :]])
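
        # Drop the initial dummy row of the outputs and costs, drop the
        # last mask row to align masks with outputs, and turn cumulative
        # costs into per-step costs.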
        all_outputs = all_outputs[1:]
        all_masks = all_masks[:-1]
        all_costs = all_costs[1:] - all_costs[:-1]
        result = all_outputs, all_masks, all_costs
        if as_arrays:
            return result
        return self.result_to_lists(result)

    @staticmethod
    def result_to_lists(result):
        """Convert the array representation of search results to lists.

        Parameters
        ----------
        result : tuple
            The (all outputs, all masks, all costs) tuple of arrays as
            returned by :meth:`search` with `as_arrays` set to ``True``.

        Returns
        -------
        The (outputs, costs) pair in the format described in
        :meth:`search`.

        """
        outputs, masks, costs = [array.T for array in result]
        outputs = [list(output[:int(mask.sum())])
                   for output, mask in equizip(outputs, masks)]
        costs = list(costs.T.sum(axis=0))
        return outputs, costs