1
|
|
|
"""The beam search module.""" |
2
|
|
|
from collections import OrderedDict |
3
|
|
|
from six.moves import range |
|
|
|
|
4
|
|
|
|
5
|
|
|
import numpy |
6
|
|
|
from picklable_itertools.extras import equizip |
7
|
|
|
from theano import config, function, tensor |
8
|
|
|
|
9
|
|
|
from blocks.bricks.sequence_generators import BaseSequenceGenerator |
10
|
|
|
from blocks.filter import VariableFilter, get_application_call, get_brick |
11
|
|
|
from blocks.graph import ComputationGraph |
12
|
|
|
from blocks.roles import INPUT, OUTPUT |
13
|
|
|
from blocks.utils import unpack |
14
|
|
|
|
15
|
|
|
|
16
|
|
|
class BeamSearch(object):
    """Approximate search for the most likely sequence.

    Beam search is an approximate algorithm for finding :math:`y^* =
    argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
    the contexts, :math:`P` is the output distribution of a
    :class:`.SequenceGenerator`. At each step it considers :math:`k`
    candidate sequence prefixes. :math:`k` is called the beam size, and the
    sequences are called the beam. The sequences are replaced with their
    :math:`k` most probable continuations, and this is repeated until
    the end-of-line symbol is met.

    The beam search compiles quite a few Theano functions under the hood.
    Normally those are compiled at the first :meth:`search` call, but
    you can also explicitly call :meth:`compile`.

    Parameters
    ----------
    samples : :class:`~theano.Variable`
        An output of a sampling computation graph built by
        :meth:`~blocks.brick.SequenceGenerator.generate`, the one
        corresponding to sampled sequences.

    See Also
    --------
    :class:`.SequenceGenerator`

    Notes
    -----
    Sequence generator should use an emitter which has `probs` method
    e.g. :class:`SoftmaxEmitter`.

    Does not support dummy contexts so far (all the contexts must be used
    in the `generate` method of the sequence generator for the current code
    to work).

    """
    def __init__(self, samples):
        # Extracting information from the sampling computation graph
        self.cg = ComputationGraph(samples)
        self.inputs = self.cg.inputs
        self.generator = get_brick(samples)
        if not isinstance(self.generator, BaseSequenceGenerator):
            raise ValueError(
                "samples must be produced by a BaseSequenceGenerator")
        self.generate_call = get_application_call(samples)
        if (not self.generate_call.application ==
                self.generator.generate):
            raise ValueError(
                "samples must be an output of the 'generate' application "
                "of the sequence generator")
        self.inner_cg = ComputationGraph(self.generate_call.inner_outputs)

        # Fetching names from the sequence generator
        self.context_names = self.generator.generate.contexts
        self.state_names = self.generator.generate.states

        # Parsing the inner computation graph of sampling scan
        self.contexts = [
            VariableFilter(bricks=[self.generator],
                           name=name,
                           roles=[INPUT])(self.inner_cg)[0]
            for name in self.context_names]
        self.input_states = []
        # Includes only those state names that were actually used
        # in 'generate'
        self.input_state_names = []
        for name in self.generator.generate.states:
            var = VariableFilter(
                bricks=[self.generator], name=name,
                roles=[INPUT])(self.inner_cg)
            if var:
                self.input_state_names.append(name)
                self.input_states.append(var[0])

        self.compiled = False

    def _compile_initial_state_and_context_computer(self):
        """Compile the function computing initial states and contexts.

        The compiled function maps the input variables of the sampling
        computation graph to the initial states, the contexts and the
        beam size (extracted from the 'batch_size' variable of the
        `initial_states` application).

        """
        initial_states = VariableFilter(
            applications=[self.generator.initial_states],
            roles=[OUTPUT])(self.cg)
        outputs = OrderedDict([(v.tag.name, v) for v in initial_states])
        beam_size = unpack(VariableFilter(
            applications=[self.generator.initial_states],
            name='batch_size')(self.cg))
        for name, context in equizip(self.context_names, self.contexts):
            outputs[name] = context
        outputs['beam_size'] = beam_size
        self.initial_state_and_context_computer = function(
            self.inputs, outputs, on_unused_input='ignore')

    def _compile_next_state_computer(self):
        """Compile the function computing the next states.

        The compiled function maps (contexts, input states, outputs
        emitted at this step) to the next states.

        """
        next_states = [VariableFilter(bricks=[self.generator],
                                      name=name,
                                      roles=[OUTPUT])(self.inner_cg)[-1]
                       for name in self.state_names]
        next_outputs = VariableFilter(
            applications=[self.generator.readout.emit], roles=[OUTPUT])(
                self.inner_cg.variables)
        self.next_state_computer = function(
            self.contexts + self.input_states + next_outputs, next_states,
            on_unused_input='ignore')

    def _compile_logprobs_computer(self):
        """Compile the function computing -log(probs) of all outputs."""
        # This filtering should return variables that are identical
        # in terms of computations, and we do not care which one to use.
        probs = VariableFilter(
            applications=[self.generator.readout.emitter.probs],
            roles=[OUTPUT])(self.inner_cg)[0]
        logprobs = -tensor.log(probs)
        self.logprobs_computer = function(
            self.contexts + self.input_states, logprobs,
            on_unused_input='ignore')

    def compile(self):
        """Compile all Theano functions used."""
        self._compile_initial_state_and_context_computer()
        self._compile_next_state_computer()
        self._compile_logprobs_computer()
        self.compiled = True

    def compute_initial_states_and_contexts(self, inputs):
        """Computes initial states and contexts from inputs.

        Parameters
        ----------
        inputs : dict
            Dictionary of input arrays.

        Returns
        -------
        A tuple containing a {name: :class:`numpy.ndarray`} dictionary of
        contexts ordered like `self.context_names` and a
        {name: :class:`numpy.ndarray`} dictionary of states ordered like
        `self.state_names`.

        """
        outputs = self.initial_state_and_context_computer(
            *[inputs[var] for var in self.inputs])
        contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names)
        beam_size = outputs.pop('beam_size')
        # What remains after popping contexts and the beam size are
        # exactly the initial states.
        initial_states = outputs
        return contexts, initial_states, beam_size

    def compute_logprobs(self, contexts, states):
        """Compute log probabilities of all possible outputs.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.

        Returns
        -------
        A :class:`numpy.ndarray` of the (beam size, number of possible
        outputs) shape.

        """
        input_states = [states[name] for name in self.input_state_names]
        return self.logprobs_computer(*(list(contexts.values()) +
                                        input_states))

    def compute_next_states(self, contexts, states, outputs):
        """Computes next states.

        Parameters
        ----------
        contexts : dict
            A {name: :class:`numpy.ndarray`} dictionary of contexts.
        states : dict
            A {name: :class:`numpy.ndarray`} dictionary of states.
        outputs : :class:`numpy.ndarray`
            A :class:`numpy.ndarray` of this step outputs.

        Returns
        -------
        A {name: numpy.array} dictionary of next states.

        """
        input_states = [states[name] for name in self.input_state_names]
        next_values = self.next_state_computer(*(list(contexts.values()) +
                                                 input_states + [outputs]))
        return OrderedDict(equizip(self.state_names, next_values))

    @staticmethod
    def _smallest(matrix, k, only_first_row=False):
        """Find k smallest elements of a matrix.

        Parameters
        ----------
        matrix : :class:`numpy.ndarray`
            The matrix.
        k : int
            The number of smallest elements required. If it is not less
            than the number of considered elements, all of them are
            returned in sorted order.
        only_first_row : bool, optional
            Consider only elements of the first row.

        Returns
        -------
        Tuple of ((row numbers, column numbers), values), with the values
        sorted in increasing order.

        """
        if only_first_row:
            costs = matrix[:1, :].flatten()
        else:
            costs = matrix.flatten()
        if k >= costs.shape[0]:
            # numpy.argpartition requires kth < size; when we ask for at
            # least as many elements as there are, simply sort them all.
            args = numpy.argsort(costs)[:k]
        else:
            args = numpy.argpartition(costs, k)[:k]
            args = args[numpy.argsort(costs[args])]
        return numpy.unravel_index(args, matrix.shape), costs[args]

    def search(self, input_values, eol_symbol, max_length,
               ignore_first_eol=False, as_arrays=False):
        """Performs beam search.

        If the beam search was not compiled, it also compiles it.

        Parameters
        ----------
        input_values : dict
            A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
            dictionary of input values. The shapes should be
            the same as if you ran sampling with batch size equal to
            `beam_size`. Put it differently, the user is responsible
            for duplicating inputs necessary number of times, because
            this class has insufficient information to do it properly.
        eol_symbol : int
            End of sequence symbol, the search stops when the symbol is
            generated.
        max_length : int
            Maximum sequence length, the search stops when it is reached.
        ignore_first_eol : bool, optional
            When ``True``, the end of sequence symbols generated at the
            first iteration are ignored. This is useful when the sequence
            generator was trained on data with identical symbols for
            sequence start and sequence end.
        as_arrays : bool, optional
            If ``True``, the internal representation of search results
            is returned, that is a (matrix of outputs, mask,
            costs of all generated outputs) tuple.

        Returns
        -------
        outputs : list of lists of ints
            A list of the `beam_size` best sequences found in the order
            of decreasing likelihood.
        costs : list of floats
            A list of the costs for the `outputs`, where cost is the
            negative log-likelihood.

        """
        if not self.compiled:
            self.compile()

        contexts, states, beam_size = self.compute_initial_states_and_contexts(
            input_values)

        # This array will store all generated outputs, including those from
        # previous step and those from already finished sequences.
        all_outputs = states['outputs'][None, :]
        all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
        all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)

        for i in range(max_length):
            if all_masks[-1].sum() == 0:
                break

            # We carefully hack values of the `logprobs` array to ensure
            # that all finished sequences are continued with `eol_symbol`.
            logprobs = self.compute_logprobs(contexts, states)
            next_costs = (all_costs[-1, :, None] +
                          logprobs * all_masks[-1, :, None])
            (finished,) = numpy.where(all_masks[-1] == 0)
            next_costs[finished, :eol_symbol] = numpy.inf
            next_costs[finished, eol_symbol + 1:] = numpy.inf

            # The `i == 0` is required because at the first step the beam
            # size is effectively only 1.
            (indexes, outputs), chosen_costs = self._smallest(
                next_costs, beam_size, only_first_row=i == 0)

            # Rearrange everything
            for name in states:
                states[name] = states[name][indexes]
            all_outputs = all_outputs[:, indexes]
            all_masks = all_masks[:, indexes]
            all_costs = all_costs[:, indexes]

            # Record chosen output and compute new states
            states.update(self.compute_next_states(contexts, states, outputs))
            all_outputs = numpy.vstack([all_outputs, outputs[None, :]])
            all_costs = numpy.vstack([all_costs, chosen_costs[None, :]])
            mask = outputs != eol_symbol
            if ignore_first_eol and i == 0:
                mask[:] = 1
            all_masks = numpy.vstack([all_masks, mask[None, :]])

        # Drop the dummy first output/cost rows and the last mask row;
        # turn cumulative costs into per-step costs.
        all_outputs = all_outputs[1:]
        all_masks = all_masks[:-1]
        all_costs = all_costs[1:] - all_costs[:-1]
        result = all_outputs, all_masks, all_costs
        if as_arrays:
            return result
        return self.result_to_lists(result)

    @staticmethod
    def result_to_lists(result):
        outputs, masks, costs = [array.T for array in result]
        outputs = [list(output[:int(mask.sum())])
                   for output, mask in equizip(outputs, masks)]
        costs = list(costs.T.sum(axis=0))
        return outputs, costs
|
|
|
|
# NOTE(review): it is generally discouraged to redefine (shadow) built-in
# names, as this makes the code much harder to read.